# Medical transcript classification [sklearn]
* Multiclass classification of medical transcripts.
* Reference notebook: <https://www.kaggle.com/code/leekahwin/text-classification-using-n-gram-0-8-f1/notebook>
* Dataset: <https://www.kaggle.com/code/leekahwin/text-classification-using-n-gram-0-8-f1/input>

By running this notebook, you’ll create a whole test suite in a few lines of code. The model used here is a Random Forest classifier trained on the medical transcript dataset. Feel free to use your own model (tabular, text, or LLM).

You’ll learn how to:

* Detect vulnerabilities by scanning the model
* Generate a test suite with domain-specific tests
* Customize your test suite by loading a test from the Giskard catalog
* Upload your model to the Giskard server to:
    * Compare models to decide which one to promote
    * Debug your tests to diagnose issues
    * Share your results and collect business feedback from your team

## Install Giskard

In [None]:
!pip install "giskard>=2.0.0b" -U

## Install necessary dependencies

In [None]:
!pip install nltk

## Import libraries

In [1]:
import os
import string
from pathlib import Path
from typing import Iterable
from urllib.request import urlretrieve

import nltk
import pandas as pd
from nltk.corpus import stopwords
from nltk.stem.snowball import SnowballStemmer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

import giskard
from giskard import Dataset, Model, GiskardClient, testing

## Download NLTK stopwords corpus

In [None]:
# Download the list of English stopwords.
nltk.download('stopwords')

## Define constants

In [2]:
# Constants.
LABELS_LIST = [
    'Neurosurgery',
    'ENT - Otolaryngology',
    'Discharge Summary',
    'General Medicine',
    'Gastroenterology',
    'Neurology',
    'SOAP / Chart / Progress Notes',
    'Obstetrics / Gynecology',
    'Urology'
]

TEXT_COLUMN_NAME = "transcription"
TARGET_COLUMN_NAME = "medical_specialty"

RANDOM_SEED = 8888

# Data.
# Join URL parts with "/" explicitly: os.path.join would insert "\" on Windows.
DATA_URL = "/".join(["ftp://sys.giskard.ai", "pub", "unit_test_resources",
                     "medical_transcript_classification_dataset", "mtsamples.csv"])
DATA_PATH = Path.home() / ".giskard" / "medical_transcript_classification_dataset" / "mtsamples.csv"

## Dataset preparation

### Load data

In [3]:
def fetch_from_ftp(url: str, file: Path) -> None:
    """Helper to fetch data from the FTP server."""
    # Create the target directory if needed (no-op when it already exists).
    file.parent.mkdir(parents=True, exist_ok=True)

    if not file.exists():
        print(f"Downloading data from {url}")
        urlretrieve(url, file)

    print("Data was loaded!")


def load_data() -> pd.DataFrame:
    """Load and initially preprocess data."""
    fetch_from_ftp(DATA_URL, DATA_PATH)

    df = pd.read_csv(DATA_PATH)

    # Drop useless columns.
    df = df.drop(columns=['Unnamed: 0', "description", "sample_name", "keywords"])

    # Trim text.
    df = df.apply(lambda x: x.str.strip())

    # Filter samples by label.
    df = df[df[TARGET_COLUMN_NAME].isin(LABELS_LIST)]

    # Drop rows with no transcript.
    df = df[df[TEXT_COLUMN_NAME].notna()]

    return df

In [4]:
transcript_df = load_data()

Data was loaded!


### Train-test split

In [5]:
X_train, X_test, y_train, y_test = train_test_split(transcript_df[[TEXT_COLUMN_NAME]],
                                                    transcript_df[TARGET_COLUMN_NAME],
                                                    random_state=RANDOM_SEED)

### Wrap dataset with Giskard

In [6]:
raw_data = pd.concat([X_test, y_test], axis=1)
wrapped_data = Dataset(raw_data,
                       name="medical_transcript_dataset",
                       target=TARGET_COLUMN_NAME,
                       column_types={TEXT_COLUMN_NAME: "text"})

## Model training

### Define preprocessing steps

In [7]:
stemmer = SnowballStemmer("english")
# A set gives constant-time membership checks during stop-word removal.
stop_words = set(stopwords.words("english"))


def preprocess_text(df: pd.DataFrame) -> pd.DataFrame:
    """Preprocess text."""
    # Work on a copy so that the caller's DataFrame is not modified in place.
    df = df.copy()

    # Lower.
    df[TEXT_COLUMN_NAME] = df[TEXT_COLUMN_NAME].apply(lambda x: x.lower())

    # Remove punctuation.
    df[TEXT_COLUMN_NAME] = df[TEXT_COLUMN_NAME].apply(lambda x: x.translate(str.maketrans('', '', string.punctuation)))

    # Tokenize.
    df[TEXT_COLUMN_NAME] = df[TEXT_COLUMN_NAME].apply(lambda x: x.split())

    # Stem.
    df[TEXT_COLUMN_NAME] = df[TEXT_COLUMN_NAME].apply(lambda x: [stemmer.stem(word) for word in x])

    # Remove stop-words.
    df[TEXT_COLUMN_NAME] = df[TEXT_COLUMN_NAME].apply(
        lambda x: ' '.join([word for word in x if word not in stop_words]))

    return df


def adapt_vectorizer_input(df: pd.DataFrame) -> Iterable:
    """Adapt input for the vectorizers.

    Vectorizers expect a one-dimensional iterable of strings (such as a Series), not a DataFrame. We therefore select the single text column so that the input has one dimension.
    Issue reference: https://stackoverflow.com/questions/50665240/valueerror-found-input-variables-with-inconsistent-numbers-of-samples-1-3185"""

    df = df.iloc[:, 0]
    return df


text_preprocessor = FunctionTransformer(preprocess_text)
vectorizer_input_adapter = FunctionTransformer(adapt_vectorizer_input)
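
To make the order of the preprocessing steps concrete, here is a minimal, standard-library-only sketch that applies the same sequence to a single string. It is an illustration rather than the notebook's code: `TOY_STOP_WORDS` and `toy_preprocess` are hypothetical names, the stop-word set is a tiny stand-in for NLTK's English list, and stemming is left out because it requires NLTK.

```python
import string

# Hypothetical, tiny stand-in for NLTK's English stop-word list.
TOY_STOP_WORDS = {"the", "was", "to", "a", "of", "and"}

# Translation table that deletes every ASCII punctuation character.
PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def toy_preprocess(text: str) -> str:
    """Lowercase, strip punctuation, split on whitespace, drop stop words."""
    text = text.lower()
    text = text.translate(PUNCTUATION_TABLE)
    tokens = text.split()
    return " ".join(token for token in tokens if token not in TOY_STOP_WORDS)


example = "The patient was admitted to the E.R., and discharged."
print(toy_preprocess(example))
```

Note that deleting punctuation fuses abbreviations such as "E.R." into a single token ("er"), which may or may not be desirable for medical text.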

### Build estimator

In [8]:
pipeline = Pipeline(steps=[
    ("text_preprocessor", text_preprocessor),
    ("vectorizer_input_adapter", vectorizer_input_adapter),
    ("vectorizer", CountVectorizer(ngram_range=(1, 1))),
    ("estimator", RandomForestClassifier(random_state=RANDOM_SEED))
])

pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)

print(classification_report(y_test, y_pred))

                               precision    recall  f1-score   support

            Discharge Summary       0.43      0.38      0.41        34
         ENT - Otolaryngology       1.00      0.59      0.74        27
             Gastroenterology       0.57      0.81      0.67        43
             General Medicine       0.46      0.61      0.52        69
                    Neurology       0.74      0.66      0.70        53
                 Neurosurgery       0.73      0.79      0.76        24
      Obstetrics / Gynecology       0.85      0.58      0.69        50
SOAP / Chart / Progress Notes       0.40      0.36      0.38        33
                      Urology       0.74      0.68      0.71        38

                     accuracy                           0.61       371
                    macro avg       0.66      0.61      0.62       371
                 weighted avg       0.64      0.61      0.62       371
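
The pipeline above uses `CountVectorizer(ngram_range=(1, 1))`, that is, unigram counts only. As a rough conceptual sketch of what n-gram counting means (not scikit-learn's implementation), the snippet below counts word n-grams for one document in plain Python. The helper name `count_ngrams` and the sample sentence are hypothetical; the token pattern is the default one documented for `CountVectorizer` (tokens of two or more word characters).

```python
import re
from collections import Counter

# Default token pattern documented for CountVectorizer:
# tokens made of two or more word characters.
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def count_ngrams(text: str, n: int = 1) -> Counter:
    """Count word n-grams in a single document (hypothetical helper)."""
    tokens = TOKEN_PATTERN.findall(text.lower())
    ngrams = (" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
    return Counter(ngrams)


doc = "patient admit chest pain chest xray"
print(count_ngrams(doc, n=1))  # unigram counts, as with ngram_range=(1, 1)
print(count_ngrams(doc, n=2))  # bigram counts, for comparison
```

Setting `ngram_range=(1, 2)` in the pipeline would make the vectorizer use both unigrams and bigrams as features, at the cost of a larger vocabulary.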



### Wrap model with Giskard

In [9]:
wrapped_model = Model(pipeline.predict_proba,
                      model_type="classification",
                      name="medical_transcript_classification",
                      feature_names=[TEXT_COLUMN_NAME],
                      classification_labels=pipeline.classes_)

# Validate wrapped model and data.
print(classification_report(y_test, pipeline.classes_[wrapped_model.predict(wrapped_data).raw_prediction]))

                               precision    recall  f1-score   support

            Discharge Summary       0.43      0.38      0.41        34
         ENT - Otolaryngology       1.00      0.59      0.74        27
             Gastroenterology       0.57      0.81      0.67        43
             General Medicine       0.46      0.61      0.52        69
                    Neurology       0.74      0.66      0.70        53
                 Neurosurgery       0.73      0.79      0.76        24
      Obstetrics / Gynecology       0.85      0.58      0.69        50
SOAP / Chart / Progress Notes       0.40      0.36      0.38        33
                      Urology       0.74      0.68      0.71        38

                     accuracy                           0.61       371
                    macro avg       0.66      0.61      0.62       371
                 weighted avg       0.64      0.61      0.62       371
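
The validation line above indexes `pipeline.classes_` with `raw_prediction`, which suggests that `raw_prediction` holds class indices rather than label names. The following plain-Python sketch shows that index-to-label mapping with made-up numbers; `classes` and `probabilities` are hypothetical stand-ins, not values from the model.

```python
# Hypothetical class labels, in the same order as the probability columns.
classes = ["Neurology", "Neurosurgery", "Urology"]

# Hypothetical predicted probabilities: one row per sample, one column per class.
probabilities = [
    [0.10, 0.70, 0.20],
    [0.55, 0.25, 0.20],
    [0.05, 0.15, 0.80],
]

# Index of the most probable class in each row.
predicted_indices = [max(range(len(row)), key=row.__getitem__) for row in probabilities]

# Map each index back to its label name.
predicted_labels = [classes[i] for i in predicted_indices]
print(predicted_labels)
```

Because the mapping relies on column order, the labels passed as `classification_labels` need to be in the same order as the columns returned by `predict_proba`, which is why the notebook passes `pipeline.classes_` directly.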



## Scan your model to find vulnerabilities
With the Giskard scan feature, you can detect vulnerabilities in your model, including performance biases, robustness issues, data leakage, stochasticity, underconfidence, ethical issues, and more. For detailed information about the scan feature, please refer to our scan documentation.

In [10]:
results = giskard.scan(wrapped_model, wrapped_data)

Your model is successfully validated.
Running scan…
2023-06-02 17:57:04,714 pid:10102 MainThread giskard.scanner.logger INFO     Running detectors: ['PerformanceBiasDetector', 'TextPerturbationDetector', 'EthicalBiasDetector', 'DataLeakageDetector', 'StochasticityDetector', 'OverconfidenceDetector', 'UnderconfidenceDetector']
Running detector PerformanceBiasDetector…2023-06-02 17:57:04,715 pid:10102 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Running
2023-06-02 17:57:04,715 pid:10102 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Calculating loss
2023-06-02 17:57:05,038 pid:10102 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Loss calculated (took 0:00:00.321848)
2023-06-02 17:57:05,038 pid:10102 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Finding data slices
2023-06-02 17:58:00,022 pid:10102 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: 321 slices found (took 0:00:54.982054)

In [11]:
display(results)

## Generate a test suite from the Scan
The objects produced by the scan can be used as fixtures to generate a test suite that integrates domain-specific issues. To create custom tests, refer to the Test your ML Model page.

In [12]:
test_suite = results.generate_test_suite("My first test suite")
test_suite.run()

Executed 'F1 Score on data slice “`transcription` contains "xyz"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x12245e380>, 'dataset': <giskard.datasets.base.Dataset object at 0x122397c70>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x126210cd0>, 'threshold': 0.581266846361186}: 
               Test failed
               Metric: 0.32
               
               
Executed 'F1 Score on data slice “`transcription` contains "subjective"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x12245e380>, 'dataset': <giskard.datasets.base.Dataset object at 0x122397c70>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x1261111b0>, 'threshold': 0.581266846361186}: 
               Test failed
               Metric: 0.37
               
               
Executed 'F1 Score on data slice “`transcription` contains "admission"”' with arguments {'model': <giskard.
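
Each generated test above reports a metric computed on a data slice (rows whose `transcription` contains a given word) and compares it with a threshold. The sketch below illustrates that idea in plain Python under stated assumptions: it is not Giskard's implementation, the helper names `macro_f1` and `run_slice_test` and the sample rows are hypothetical, and macro averaging is an assumption about how the F1 score is aggregated across classes.

```python
from typing import List, Sequence, Tuple


def macro_f1(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Unweighted mean of per-class F1 over all labels seen in either sequence."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    scores = []
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        denominator = 2 * tp + fp + fn
        scores.append(2 * tp / denominator if denominator else 0.0)
    return sum(scores) / len(scores)


def run_slice_test(rows: List[Tuple[str, str, str]], keyword: str,
                   threshold: float) -> Tuple[bool, float]:
    """rows are (text, true_label, predicted_label); returns (passed, metric)."""
    selected = [(true, pred) for text, true, pred in rows if keyword in text.lower()]
    if not selected:
        raise ValueError(f"No rows contain {keyword!r}")
    y_true, y_pred = zip(*selected)
    metric = macro_f1(y_true, y_pred)
    return metric >= threshold, metric


# Made-up rows for illustration only.
rows = [
    ("Subjective: headache", "Neurology", "General Medicine"),
    ("Subjective: abdominal pain", "Gastroenterology", "Gastroenterology"),
    ("Admission for surgery", "Urology", "Urology"),
    ("Subjective: follow-up", "General Medicine", "General Medicine"),
]
passed, metric = run_slice_test(rows, keyword="subjective", threshold=0.58)
print(f"passed={passed}, metric={metric:.2f}")
```

On these toy rows, the "subjective" slice scores about 0.56 and therefore fails a threshold of 0.58, mirroring the "Test failed" entries in the output above.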

## Customize your suite by loading objects from the Giskard catalog

The Giskard open source catalog lets you load:

* Tests such as metamorphic, performance, prediction & data drift, statistical tests, etc
* Slicing functions such as detectors of toxicity, hate, emotion, etc
* Transformation functions such as generators of typos, paraphrase, style tune, etc

For demo purposes, we will load a simple unit test (test_f1) that checks if the test F1 score is above the given threshold. For more examples of tests and functions, refer to the Giskard catalog.

In [14]:
test_suite.add_test(testing.test_f1(model=wrapped_model, dataset=wrapped_data, threshold=0.7)).run()

Executed 'F1 Score on data slice “`transcription` contains "xyz"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x12245e380>, 'dataset': <giskard.datasets.base.Dataset object at 0x122397c70>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x126210cd0>, 'threshold': 0.581266846361186}: 
               Test failed
               Metric: 0.32
               
               
Executed 'F1 Score on data slice “`transcription` contains "subjective"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x12245e380>, 'dataset': <giskard.datasets.base.Dataset object at 0x122397c70>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x1261111b0>, 'threshold': 0.581266846361186}: 
               Test failed
               Metric: 0.37
               
               
Executed 'F1 Score on data slice “`transcription` contains "admission"”' with arguments {'model': <giskard.

## Upload your suite to the Giskard server

Upload your suite to the Giskard server to:

* Compare models to decide which model to promote
* Debug your tests to diagnose the issues
* Create more domain-specific tests that integrate business feedback
* Share your results

In [None]:
# Uploading the test suite automatically saves the model, dataset, tests, and slicing & transformation functions to the Giskard server
# Create a Giskard client after installing the Giskard server (see documentation)
token = "API_TOKEN"  # Find it in Settings in the Giskard server

client = GiskardClient(
    url="http://localhost:19000",  # URL of your Giskard instance
    token=token
)

my_project = client.create_project("my_project", "PROJECT_NAME", "DESCRIPTION")

# Upload to the current project ✉️
test_suite.upload(client, "my_project")

<div class="alert alert-info">
Connecting Google Colab with the Giskard server

If you are using Google Colab and you want to install the Giskard server **locally**, you can run the Giskard server by executing this line in the terminal of your **local** machine (see the [documentation](https://docs.giskard.ai/en/latest/guides/installation_app/index.html)):

> giskard server start

Once the Giskard server is running, from the same terminal on your **local** machine, you can run:

> giskard server expose

This will provide you with the code snippets that you can copy and paste into your Colab notebook to establish a connection with your locally installed Giskard server.
</div>