# Medical transcript classification [sklearn]
* Multiclass classification of medical transcripts by medical specialty.
* Reference notebook: <https://www.kaggle.com/code/leekahwin/text-classification-using-n-gram-0-8-f1/notebook>
* Dataset: <https://www.kaggle.com/code/leekahwin/text-classification-using-n-gram-0-8-f1/input>

## Install necessary dependencies

In [None]:
!pip install nltk

## Import libraries

In [None]:
import os
import string
from typing import Iterable

import nltk
import numpy as np
import pandas as pd
from nltk.corpus import stopwords
from sklearn.pipeline import Pipeline
from nltk.stem.snowball import SnowballStemmer
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import FunctionTransformer
from sklearn.feature_extraction.text import CountVectorizer

import giskard
from giskard import Dataset, Model, GiskardClient
from giskard.client.giskard_client import GiskardError

## Download NLTK stopwords corpus

In [None]:
# Fetch the NLTK "stopwords" corpus; the English list is read from it further down.
nltk.download("stopwords")

## Define constants

In [None]:
# Medical specialties kept for classification; rows with any other label are filtered out.
LABELS_LIST = [
    "Neurosurgery",
    "ENT - Otolaryngology",
    "Discharge Summary",
    "General Medicine",
    "Gastroenterology",
    "Neurology",
    "SOAP / Chart / Progress Notes",
    "Obstetrics / Gynecology",
    "Urology",
]

# Column with the transcript text (the only model feature).
TEXT_COLUMN_NAME = "transcription"
# Column with the specialty label (the prediction target).
TARGET_COLUMN_NAME = "medical_specialty"

# Seed passed to both the train/test split and the estimator.
RANDOM_SEED = 8888

# Giskard connection settings. The token is intentionally blank here.
GISKARD_URL = "http://localhost:9000"
GISKARD_TOKEN = ""
GISKARD_PROJECT_KEY = "medical_transcript_classification"

# Dataset CSV location, relative to the current working directory.
PATH_DATA = os.path.join(
    os.curdir,
    "datasets",
    "medical_transcript_classification_dataset",
    "mtsamples.csv",
)

## Load data

In [None]:
def load_data() -> pd.DataFrame:
    """Read the transcripts CSV and keep only usable, in-scope rows.

    Reads ``PATH_DATA``, discards the columns that are not used for
    modelling, strips surrounding whitespace from every remaining (string)
    column, and keeps the rows whose label is in ``LABELS_LIST`` and whose
    transcript is present.

    Returns:
        The filtered DataFrame; the original row index is preserved.
    """
    unused_columns = ['Unnamed: 0', "description", "sample_name", "keywords"]
    frame = pd.read_csv(PATH_DATA).drop(columns=unused_columns)

    # Strip leading/trailing whitespace column by column (missing values stay missing).
    frame = frame.apply(lambda column: column.str.strip())

    # A row survives only if its label is selected AND it has a transcript.
    has_known_label = frame[TARGET_COLUMN_NAME].isin(LABELS_LIST)
    has_transcript = frame[TEXT_COLUMN_NAME].notna()
    return frame[has_known_label & has_transcript]

# Load the cleaned, label-filtered transcripts once for the rest of the notebook.
transcript_df = load_data()

## Train-test split

In [None]:
# Split into train and test parts; only the seed is set, the split size is sklearn's default.
X_train, X_test, y_train, y_test = train_test_split(
    transcript_df[[TEXT_COLUMN_NAME]],  # double brackets: features stay a one-column DataFrame
    transcript_df[TARGET_COLUMN_NAME],  # labels as a Series
    random_state=RANDOM_SEED,
)

## Wrap dataset with giskard

In [None]:
# Put the held-out features and labels side by side so the wrapped frame
# contains the target column named below.
raw_data = pd.concat([X_test, y_test], axis="columns")

wrapped_data = Dataset(
    raw_data,
    name="medical_transcript_dataset",
    target=TARGET_COLUMN_NAME,
    column_types={TEXT_COLUMN_NAME: "text"},
)

## Define preprocessing steps

In [None]:
stemmer = SnowballStemmer("english")
stop_words = stopwords.words("english")


def preprocess_text(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize the transcript column for vectorization.

    Each text is lower-cased, stripped of ASCII punctuation
    (``string.punctuation``), split on whitespace, stemmed, and filtered
    against the English stop-word list; the surviving stems are re-joined
    with single spaces.

    Args:
        df: Frame containing a string column named ``TEXT_COLUMN_NAME``.

    Returns:
        A new DataFrame whose text column holds the normalized text.
        The frame passed in is left unmodified.
    """
    # Work on a copy: this function runs inside the pipeline on every
    # fit/predict call, and writing into the caller's frame would silently
    # rewrite X_train / X_test (and any frame handed in by external tooling).
    df = df.copy()

    # Built once per call instead of once per row.
    punctuation_table = str.maketrans('', '', string.punctuation)
    # Set membership is O(1); the tokens kept are the same as with the list.
    stop_word_set = frozenset(stop_words)

    def _normalize(text: str) -> str:
        """Lower-case, de-punctuate, tokenize, stem and stop-word-filter one text."""
        tokens = text.lower().translate(punctuation_table).split()
        stems = (stemmer.stem(token) for token in tokens)
        # NOTE(review): stop-words are removed *after* stemming, so a stop-word
        # whose stem differs from its surface form is presumably not filtered.
        # The order is kept as-is so the resulting vocabulary does not change.
        return ' '.join(stem for stem in stems if stem not in stop_word_set)

    df[TEXT_COLUMN_NAME] = df[TEXT_COLUMN_NAME].apply(_normalize)
    return df


# Make the function usable as a pipeline step.
text_preprocessor = FunctionTransformer(preprocess_text)

In [None]:
def adapt_vectorizer_input(df: pd.DataFrame) -> Iterable:
    """Turn a one-column text frame into the 1-D input a vectorizer expects.

    Text vectorizers consume an iterable of documents. Iterating over a
    DataFrame yields its column labels rather than its rows, so the text
    column has to be handed over as a Series instead.
    Issue reference: https://stackoverflow.com/questions/50665240/valueerror-found-input-variables-with-inconsistent-numbers-of-samples-1-3185

    Args:
        df: Frame whose first column holds the text.

    Returns:
        The first column of ``df`` as a Series.
    """
    return df.iloc[:, 0]

# Wrap the adapter in a FunctionTransformer so it can be used as a pipeline step.
vectorizer_input_adapter = FunctionTransformer(adapt_vectorizer_input)

## Define final pipeline

In [None]:
# End-to-end model: clean text -> 1-D series -> unigram counts -> random forest.
pipeline = Pipeline(
    [
        ("text_preprocessor", text_preprocessor),
        ("vectorizer_input_adapter", vectorizer_input_adapter),
        # ngram_range=(1, 1): single-token (unigram) features only.
        ("vectorizer", CountVectorizer(ngram_range=(1, 1))),
        # Seeded so repeated runs build the same forest.
        ("estimator", RandomForestClassifier(random_state=RANDOM_SEED)),
    ]
)

## Fit and test estimator

In [None]:
# Train on the training part, then predict labels for the held-out part.
# `fit` returns the fitted pipeline, so the two calls can be chained.
y_pred = pipeline.fit(X_train, y_train).predict(X_test)

# Per-class precision / recall / F1 on the held-out part.
report = classification_report(y_test, y_pred)
print(report)

## Define prediction function

In [None]:
def prediction_function(df: pd.DataFrame) -> np.ndarray:
    """Return the fitted pipeline's class probabilities for the rows of ``df``.

    Args:
        df: Frame with the same feature column the pipeline was fitted on.

    Returns:
        The output of ``pipeline.predict_proba`` for ``df``.
    """
    probabilities = pipeline.predict_proba(df)
    return probabilities

## Wrap model with giskard

In [None]:
wrapped_model = Model(
    prediction_function,
    model_type="classification",
    name="medical_transcript_classification",
    feature_names=[TEXT_COLUMN_NAME],
    # Labels are taken from the fitted pipeline, the same object that
    # produces the probabilities inside prediction_function.
    classification_labels=pipeline.classes_,
)

# Sanity check of the wrapped model on the wrapped data: raw_prediction is used
# to index pipeline.classes_, turning the predictions back into label names.
print(
    classification_report(
        y_test,
        pipeline.classes_[wrapped_model.predict(wrapped_data).raw_prediction],
    )
)

## Perform model scan

In [None]:
# Run the giskard scan of the wrapped model against the wrapped held-out data.
scanning_results = giskard.scan(wrapped_model, wrapped_data)

In [None]:
# Render the scan results in the notebook output.
display(scanning_results)

## Upload model and dataset to the giskard UI platform

In [None]:
# Init giskard client.
client = GiskardClient(GISKARD_URL, GISKARD_TOKEN)

# Create the project; if creation is rejected, fall back to fetching the
# project that is registered under the same key.
try:
    project = client.create_project(GISKARD_PROJECT_KEY,
                                    name="MEDICAL_TRANSCRIPTS_CLASSIFICATION",
                                    description="Multiclass classification of the diagnosis based on medical transcript text.")
except GiskardError as e:
    # The usual cause is an already-existing key, but report the actual error
    # so other failures are not mislabelled as "already exists".
    print(f"Could not create project with key {GISKARD_PROJECT_KEY} ({e}). Trying to get the existing one.")
    project = client.get_project(GISKARD_PROJECT_KEY)

# Upload model and dataset to the project.
model_id = wrapped_model.upload(client, GISKARD_PROJECT_KEY)
dataset_id = wrapped_data.upload(client, GISKARD_PROJECT_KEY)