
Commit

Merge pull request #112 from CDU-data-science-team/0.5.4
0.5.4
yiwen-h committed Jun 15, 2023
2 parents 8cad434 + 842271f commit 0b65e33
Showing 26 changed files with 2,449 additions and 2,240 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,3 +12,4 @@ test_multilabel/*
.env
api/rsconnect-python/*
.coverage
api/bert*
162 changes: 149 additions & 13 deletions api/api.py
@@ -1,15 +1,71 @@
import os
import pickle
from typing import List
from typing import List, Union

import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel, validator
from tensorflow.keras.saving import load_model

from pxtextmining.factories.factory_predict_unlabelled_text import (
predict_multilabel_sklearn,
predict_sentiment_bert,
)

minor_cats_v5 = [
"Gratitude/ good experience",
"Negative experience",
"Not assigned",
"Organisation & efficiency",
"Funding & use of financial resources",
"Non-specific praise for staff",
"Non-specific dissatisfaction with staff",
"Staff manner & personal attributes",
"Number & deployment of staff",
"Staff responsiveness",
"Staff continuity",
"Competence & training",
"Unspecified communication",
"Staff listening, understanding & involving patients",
"Information directly from staff during care",
"Information provision & guidance",
"Being kept informed, clarity & consistency of information",
"Service involvement with family/ carers",
"Patient contact with family/ carers",
"Contacting services",
"Appointment arrangements",
"Appointment method",
"Timeliness of care",
"Pain management",
"Diagnosis & triage",
"Referals & continuity of care",
"Length of stay/ duration of care",
"Discharge",
"Care plans",
"Patient records",
"Links with non-NHS organisations",
"Cleanliness, tidiness & infection control",
"Safety & security",
"Provision of medical equipment",
"Service location",
"Transport to/ from services",
"Parking",
"Electronic entertainment",
"Feeling safe",
"Patient appearance & grooming",
"Mental Health Act",
"Equality, Diversity & Inclusion",
"Admission",
"Collecting patients feedback",
"Labelling not possible",
"Environment & Facilities",
"Supplying & understanding medication",
"Activities & access to fresh air",
"Food & drink provision & facilities",
"Sensory experience",
"Impact of treatment/ care",
]

description = """
This API is for classifying patient experience qualitative data,
utilising the models trained as part of the pxtextmining project.
@@ -18,21 +74,42 @@
tags_metadata = [
{"name": "index", "description": "Basic page to test if API is working."},
{
"name": "predict",
"name": "multilabel",
"description": "Generate multilabel predictions for given text.",
},
{
"name": "sentiment",
"description": "Generate predicted sentiment for given text.",
},
]


async def load_sentiment_model():
model_path = "bert_sentiment"
if not os.path.exists(model_path):
model_path = os.path.join("api", model_path)
loaded_model = load_model(model_path)
return loaded_model


async def get_sentiment_predictions(
text_to_predict, loaded_model, preprocess_text, additional_features
):
predictions = predict_sentiment_bert(
text_to_predict,
loaded_model,
preprocess_text=preprocess_text,
additional_features=additional_features,
)
return predictions


class Test(BaseModel):
test: str

class Config:
schema_extra = {
"example": {
"test": "Hello"
}
}
schema_extra = {"example": {"test": "Hello"}}


class ItemIn(BaseModel):
comment_id: str
@@ -57,7 +134,7 @@ def question_type_validation(cls, v):
return v


class ItemOut(BaseModel):
class MultilabelOut(BaseModel):
comment_id: str
comment_text: str
labels: list
@@ -72,6 +149,21 @@ class Config:
}


class SentimentOut(BaseModel):
comment_id: str
comment_text: str
sentiment: Union[int, str]

class Config:
schema_extra = {
"example": {
"comment_id": "01",
"comment_text": "Nurses were friendly. Parking was awful.",
"sentiment": 3,
}
}


app = FastAPI(
title="pxtextmining API",
description=description,
@@ -85,17 +177,19 @@ class Config:
"name": "MIT License",
"url": "https://github.com/CDU-data-science-team/pxtextmining/blob/main/LICENSE",
},
openapi_tags=tags_metadata
openapi_tags=tags_metadata,
)


@app.get("/", response_model=Test, tags=['index'])
@app.get("/", response_model=Test, tags=["index"])
def index():
return {"test": "Hello"}


@app.post("/predict_multilabel", response_model=List[ItemOut], tags=['predict'])
def predict(items: List[ItemIn]):
@app.post(
"/predict_multilabel", response_model=List[MultilabelOut], tags=["multilabel"]
)
async def predict_multilabel(items: List[ItemIn]):
"""Accepts comment ids, comment text and question type as JSON in a POST request. Makes predictions using trained SVC model.
Args:
@@ -128,7 +222,7 @@ def predict(items: List[ItemIn]):
with open(model_path, "rb") as model:
loaded_model = pickle.load(model)
preds_df = predict_multilabel_sklearn(
text_to_predict, loaded_model, additional_features=True
text_to_predict, loaded_model, labels=minor_cats_v5, additional_features=True
)
# Join predicted labels with received data
preds_df["comment_id"] = preds_df.index.astype(str)
@@ -141,3 +235,45 @@ def predict(items: List[ItemIn]):
orient="records"
)
return return_dict
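
(Editor's sketch, not part of this commit.) For context, this is roughly how the multilabel endpoint above could be exercised once the app is served locally; the base URL comes from api/test_api.py, the payload and response shapes from the ItemIn/MultilabelOut models, and everything else is assumed:

import requests

# Hypothetical client call for /predict_multilabel (illustration only).
comments = [
    {"comment_id": "1", "comment_text": "Thank you", "question_type": "what_good"},
    {"comment_id": "2", "comment_text": "Food was cold", "question_type": "could_improve"},
]
response = requests.post("http://127.0.0.1:8000/predict_multilabel", json=comments)
response.raise_for_status()
for item in response.json():
    # Each item carries comment_id, comment_text and the predicted labels list
    print(item["comment_id"], item["labels"])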


@app.post("/predict_sentiment", response_model=List[SentimentOut], tags=["sentiment"])
async def predict_sentiment(items: List[ItemIn]):
"""Accepts comment ids, comment text and question type as JSON in a POST request. Makes predictions using trained Tensorflow Keras model.
Args:
items (List[ItemIn]): JSON list of dictionaries with the following compulsory keys:
- `comment_id` (str)
- `comment_text` (str)
- `question_type` (str)
The 'question_type' must be one of three values: 'nonspecific', 'what_good', and 'could_improve'.
For example, `[{'comment_id': '1', 'comment_text': 'Thank you', 'question_type': 'what_good'},
{'comment_id': '2', 'comment_text': 'Food was cold', 'question_type': 'could_improve'}]`
Returns:
(dict): Keys are: `comment_id`, `comment_text`, and predicted `sentiment`.
"""

# Process received data
loaded_model = await load_sentiment_model()
df = pd.DataFrame([i.dict() for i in items], dtype=str)
df_newindex = df.set_index("comment_id")
if df_newindex.index.duplicated().sum() != 0:
raise ValueError("comment_id must all be unique values")
df_newindex.index.rename("Comment ID", inplace=True)
text_to_predict = df_newindex[["comment_text", "question_type"]]
text_to_predict = text_to_predict.rename(
columns={"comment_text": "FFT answer", "question_type": "FFT_q_standardised"}
)
# Make predictions
preds_df = await get_sentiment_predictions(
text_to_predict, loaded_model, preprocess_text=False, additional_features=True
)
# Join predicted labels with received data
preds_df["comment_id"] = preds_df.index.astype(str)
merged = pd.merge(df, preds_df, how="left", on="comment_id")
merged["sentiment"] = merged["sentiment"].fillna("Labelling not possible")
return_dict = merged[["comment_id", "comment_text", "sentiment"]].to_dict(
orient="records"
)
return return_dict
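
(Editor's sketch, not part of the diff.) The new sentiment endpoint takes the same request shape; per SentimentOut and the fillna above, `sentiment` is returned as an int, or the string "Labelling not possible" where no prediction could be made. A minimal sketch, assuming the same local server:

import requests

# Hypothetical client call for the new /predict_sentiment endpoint (illustration only).
comments = [
    {
        "comment_id": "01",
        "comment_text": "Nurses were friendly. Parking was awful.",
        "question_type": "nonspecific",
    },
]
response = requests.post("http://127.0.0.1:8000/predict_sentiment", json=comments)
for item in response.json():
    # sentiment is an int, or "Labelling not possible" if no prediction was made
    print(item["comment_id"], item["sentiment"])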
43 changes: 22 additions & 21 deletions api/requirements.txt
@@ -1,41 +1,41 @@
absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11"
astunparse==1.6.3 ; python_version >= "3.8" and python_version < "3.11"
blis==0.7.9 ; python_version >= "3.8" and python_version < "3.11"
cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11"
cachetools==5.3.1 ; python_version >= "3.8" and python_version < "3.11"
catalogue==2.0.8 ; python_version >= "3.8" and python_version < "3.11"
certifi==2023.5.7 ; python_version >= "3.8" and python_version < "3.11"
charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11"
click==8.1.3 ; python_version >= "3.8" and python_version < "3.11"
colorama==0.4.6 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows" or sys_platform == "win32" and python_version >= "3.8" and python_version < "3.11"
confection==0.0.4 ; python_version >= "3.8" and python_version < "3.11"
contourpy==1.0.7 ; python_version >= "3.8" and python_version < "3.11"
contourpy==1.1.0 ; python_version >= "3.8" and python_version < "3.11"
cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11"
cymem==2.0.7 ; python_version >= "3.8" and python_version < "3.11"
filelock==3.12.0 ; python_version >= "3.8" and python_version < "3.11"
flatbuffers==23.5.9 ; python_version >= "3.8" and python_version < "3.11"
fonttools==4.39.4 ; python_version >= "3.8" and python_version < "3.11"
fsspec==2023.5.0 ; python_version >= "3.8" and python_version < "3.11"
filelock==3.12.2 ; python_version >= "3.8" and python_version < "3.11"
flatbuffers==23.5.26 ; python_version >= "3.8" and python_version < "3.11"
fonttools==4.40.0 ; python_version >= "3.8" and python_version < "3.11"
fsspec==2023.6.0 ; python_version >= "3.8" and python_version < "3.11"
gast==0.4.0 ; python_version >= "3.8" and python_version < "3.11"
google-auth-oauthlib==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
google-auth==2.18.1 ; python_version >= "3.8" and python_version < "3.11"
google-auth==2.17.3 ; python_version >= "3.8" and python_version < "3.11"
google-pasta==0.2.0 ; python_version >= "3.8" and python_version < "3.11"
grpcio==1.54.2 ; python_version >= "3.8" and python_version < "3.11"
h5py==3.8.0 ; python_version >= "3.8" and python_version < "3.11"
huggingface-hub==0.14.1 ; python_version >= "3.8" and python_version < "3.11"
huggingface-hub==0.15.1 ; python_version >= "3.8" and python_version < "3.11"
idna==3.4 ; python_version >= "3.8" and python_version < "3.11"
importlib-metadata==6.6.0 ; python_version >= "3.8" and python_version < "3.10"
importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.10"
jax==0.4.10 ; python_version >= "3.8" and python_version < "3.11"
jax==0.4.12 ; python_version >= "3.8" and python_version < "3.11"
jinja2==3.1.2 ; python_version >= "3.8" and python_version < "3.11"
joblib==1.2.0 ; python_version >= "3.8" and python_version < "3.11"
keras==2.12.0 ; python_version >= "3.8" and python_version < "3.11"
kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11"
langcodes==3.3.0 ; python_version >= "3.8" and python_version < "3.11"
libclang==16.0.0 ; python_version >= "3.8" and python_version < "3.11"
markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11"
markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11"
markupsafe==2.1.3 ; python_version >= "3.8" and python_version < "3.11"
matplotlib==3.7.1 ; python_version >= "3.8" and python_version < "3.11"
ml-dtypes==0.1.0 ; python_version >= "3.8" and python_version < "3.11"
ml-dtypes==0.2.0 ; python_version >= "3.8" and python_version < "3.11"
murmurhash==1.0.9 ; python_version >= "3.8" and python_version < "3.11"
numpy==1.23.5 ; python_version < "3.11" and python_version >= "3.8"
oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11"
@@ -45,18 +45,19 @@ pandas==1.5.3 ; python_version >= "3.8" and python_version < "3.11"
pathy==0.10.1 ; python_version >= "3.8" and python_version < "3.11"
pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11"
preshed==3.0.8 ; python_version >= "3.8" and python_version < "3.11"
protobuf==4.23.1 ; python_version >= "3.8" and python_version < "3.11"
protobuf==4.23.3 ; python_version >= "3.8" and python_version < "3.11"
pyasn1-modules==0.3.0 ; python_version >= "3.8" and python_version < "3.11"
pyasn1==0.5.0 ; python_version >= "3.8" and python_version < "3.11"
pydantic==1.10.8 ; python_version >= "3.8" and python_version < "3.11"
pydantic==1.10.9 ; python_version >= "3.8" and python_version < "3.11"
pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11"
python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11"
pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11"
pyyaml==6.0 ; python_version >= "3.8" and python_version < "3.11"
regex==2023.5.5 ; python_version >= "3.8" and python_version < "3.11"
regex==2023.6.3 ; python_version >= "3.8" and python_version < "3.11"
requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11"
requests==2.31.0 ; python_version >= "3.8" and python_version < "3.11"
rsa==4.9 ; python_version >= "3.8" and python_version < "3.11"
safetensors==0.3.1 ; python_version >= "3.8" and python_version < "3.11"
scikit-learn==1.0.2 ; python_version >= "3.8" and python_version < "3.11"
scipy==1.10.1 ; python_version >= "3.8" and python_version < "3.11"
setuptools-scm==7.1.0 ; python_version >= "3.8" and python_version < "3.11"
@@ -67,7 +68,7 @@ spacy-legacy==3.0.12 ; python_version >= "3.8" and python_version < "3.11"
spacy-loggers==1.0.4 ; python_version >= "3.8" and python_version < "3.11"
spacy==3.5.3 ; python_version >= "3.8" and python_version < "3.11"
srsly==2.4.6 ; python_version >= "3.8" and python_version < "3.11"
tensorboard-data-server==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
tensorboard-data-server==0.7.1 ; python_version >= "3.8" and python_version < "3.11"
tensorboard==2.12.3 ; python_version >= "3.8" and python_version < "3.11"
tensorflow-estimator==2.12.0 ; python_version >= "3.8" and python_version < "3.11"
tensorflow-io-gcs-filesystem==0.32.0 ; python_version >= "3.8" and python_version < "3.11" and platform_machine != "arm64" or python_version >= "3.8" and python_version < "3.11" and platform_system != "Darwin"
@@ -78,14 +79,14 @@ threadpoolctl==3.1.0 ; python_version >= "3.8" and python_version < "3.11"
tokenizers==0.13.3 ; python_version >= "3.8" and python_version < "3.11"
tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11"
tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11"
transformers==4.29.2 ; python_version >= "3.8" and python_version < "3.11"
transformers==4.30.2 ; python_version >= "3.8" and python_version < "3.11"
typer==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
typing-extensions==4.6.1 ; python_version >= "3.8" and python_version < "3.11"
urllib3==1.26.16 ; python_version >= "3.8" and python_version < "3.11"
wasabi==1.1.1 ; python_version >= "3.8" and python_version < "3.11"
werkzeug==2.3.4 ; python_version >= "3.8" and python_version < "3.11"
typing-extensions==4.6.3 ; python_version >= "3.8" and python_version < "3.11"
urllib3==2.0.3 ; python_version >= "3.8" and python_version < "3.11"
wasabi==1.1.2 ; python_version >= "3.8" and python_version < "3.11"
werkzeug==2.3.6 ; python_version >= "3.8" and python_version < "3.11"
wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11"
wrapt==1.14.1 ; python_version >= "3.8" and python_version < "3.11"
xgboost==1.7.5 ; python_version >= "3.8" and python_version < "3.11"
zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10"
pxtextmining==0.5.3
pxtextmining==0.5.4
22 changes: 15 additions & 7 deletions api/test_api.py
@@ -1,6 +1,6 @@
import requests
import pandas as pd
import json
import time

"""
To test the API, first run the following command in a terminal to launch the uvicorn server on http://127.0.0.1:8000
@@ -10,13 +10,20 @@


def test_json_predictions(json):
response = requests.post("http://127.0.0.1:8000/predict_multilabel", json=json)
# response = requests.post("http://127.0.0.1:8000/predict_multilabel", json=json)
response = requests.post("http://127.0.0.1:8000/predict_sentiment", json=json)
return response


if __name__ == "__main__":
df = pd.read_csv("datasets/hidden/API_test.csv")
df = df[["row_id", "comment_txt"]].copy().set_index("row_id")[:20]
start = time.time()
df = pd.read_csv("datasets/hidden/merged_230612.csv")[["Comment ID", "FFT answer"]][
:2000
]
df = df.rename(
columns={"Comment ID": "row_id", "FFT answer": "comment_txt"}
).dropna()
df = df[["row_id", "comment_txt"]].copy().set_index("row_id")[:1000]
js = []
for i in df.index:
js.append(
@@ -28,11 +35,12 @@ def test_json_predictions(json):
)
print("The JSON that was sent looks like:")
print(js[:5])
print("The JSON that was sent looks like:")
print(js[:5])
print("The JSON that is returned is:")
returned_json = test_json_predictions(js).json()
print(returned_json)
finish = time.time()
total = finish - start
print(f"Time taken: {total} seconds")
print(returned_json[:10])
# json_object = json.dumps(returned_json, indent=4)
# with open("predictions.json", "w") as outfile:
# outfile.write(json_object)