![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/Spark_NLP_Udemy_MOOC/Healthcare_NLP/GenericLogRegClassifierApproach.ipynb)

# **GenericLogRegClassifierApproach**

This notebook covers the parameters and usage of the `GenericLogRegClassifierApproach` annotator.

**📖 Learning Objectives:**

1. Understand how to use `GenericLogRegClassifierApproach`.

2. Become comfortable using the different parameters of the annotator.




**🔗 Helpful Links:**

- Documentation : [GenericLogRegClassifierApproach](https://nlp.johnsnowlabs.com/docs/en/licensed_annotators#genericlogregclassifier)

- For extended examples of usage, see the [Spark NLP Workshop](https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/8.Generic_Classifier.ipynb)

- Python Docs : [GenericLogRegClassifierApproach](https://nlp.johnsnowlabs.com/licensed/api/python/reference/autosummary/sparknlp_jsl/annotator/classification/generic_log_reg_classifier/index.html)

- Scala Docs : [GenericLogRegClassifierApproach](https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/8.Generic_Classifier.ipynb)



## **📜 Background**


`GenericLogRegClassifier` is a derivative of `GenericClassifier` that implements multinomial logistic regression: a single-layer neural network with the logistic function at the output. The model takes `FEATURE_VECTOR` annotations as input and produces `CATEGORY` annotations, each carrying a label and a confidence score between 0 and 1.
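
To make the description above concrete, here is a minimal sketch of a single linear layer followed by a logistic output, written in plain Python. It is not the annotator's implementation; the helper `predict_scores` and all weights, biases, and feature values are made up for illustration.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function: maps any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict_scores(features, weights, biases):
    """Apply one linear layer, then a logistic output per class.

    features: list of floats (the feature vector)
    weights:  one list of floats per class, same length as features
    biases:   one float per class
    """
    scores = []
    for class_weights, bias in zip(weights, biases):
        z = sum(w * x for w, x in zip(class_weights, features)) + bias
        scores.append(sigmoid(z))
    return scores


# Hypothetical 3-dimensional feature vector and two classes.
features = [0.5, -1.0, 2.0]
weights = [[0.1, 0.2, 0.3], [-0.3, 0.1, -0.2]]
biases = [0.0, 0.1]
labels = ["neg", "pos"]

scores = predict_scores(features, weights, biases)
best = max(range(len(labels)), key=lambda i: scores[i])
print(labels[best], scores)
```

The label with the highest score plays the role of the predicted category, and the score itself corresponds to the confidence value mentioned above.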


## **🎬 Colab Setup**

In [1]:
# Install the johnsnowlabs library to access Spark-NLP for Healthcare
! pip install -q johnsnowlabs


In [2]:
from google.colab import files
print('Please Upload your John Snow Labs License using the button below')
license_keys = files.upload()

Please Upload your John Snow Labs License using the button below


Saving 5.3.3.spark_nlp_for_healthcare.json to 5.3.3.spark_nlp_for_healthcare.json


In [3]:
from johnsnowlabs import nlp, medical

# After uploading your license, run this to install all licensed Python wheels and pre-download the Jars for the Spark Session JVM

nlp.install()

👌 Detected license file /content/5.3.3.spark_nlp_for_healthcare.json
🚨 Outdated Medical Secrets in license file. Version=5.3.3 but should be Version=5.3.2
🚨 Outdated OCR Secrets in license file. Version=5.1.2 but should be Version=5.3.2
👷 Trying to install compatible secrets. Use nlp.settings.enforce_versions=False if you want to install outdated secrets.
📋 Stored John Snow Labs License in /root/.johnsnowlabs/licenses/license_number_0_for_Spark-Healthcare_Spark-OCR.json
👷 Setting up  John Snow Labs home in /root/.johnsnowlabs, this might take a few minutes.
Downloading 🐍+🚀 Python Library spark_nlp-5.3.2-py2.py3-none-any.whl
Downloading 🐍+💊 Python Library spark_nlp_jsl-5.3.2-py3-none-any.whl
Downloading 🫘+🚀 Java Library spark-nlp-assembly-5.3.2.jar
Downloading 🫘+💊 Java Library spark-nlp-jsl-5.3.2.jar
🙆 JSL Home setup in /root/.johnsnowlabs
👌 Detected license file /content/5.3.3.spark_nlp_for_healthcare.json
👷 Trying to install compatibl

In [4]:
# Automatically load license data and start a session with all jars user has access to
spark = nlp.start()

👌 Detected license file /content/5.3.3.spark_nlp_for_healthcare.json
👷 Trying to install compatible secrets. Use nlp.settings.enforce_versions=False if you want to install outdated secrets.
👌 Launched cpu optimized session with with: 🚀Spark-NLP==5.3.2, 💊Spark-Healthcare==5.3.2, running on ⚡ PySpark==3.4.0


In [5]:
spark

In [6]:
import pandas as pd
from sklearn.metrics import classification_report

## **🖨️ Input/Output Annotation Types**

- Input: `FEATURE_VECTOR`

- Output: `CATEGORY`

## **🔎 Parameters**


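- `inputCols`: The name of the column(s) containing the input annotations (`FEATURE_VECTOR`). It accepts either a single column name (string) or an array of names.
- `outputCol`: The name of the generated output column, which holds `CATEGORY` annotations. Only one column can be specified.
- `labelColumn`: The column with one label per document.
- `epochsNumber`: The number of epochs for the optimization process.
- `batchSize`: The size of each batch in the optimization process.
- `learningRate`: The learning rate for the optimization process.
- `dropout`: The dropout applied at the output of each layer.
- `fixImbalance`: A flag indicating whether to balance the training set.
- `modelFile`: The path to the TensorFlow graph file used for training (here, the graph produced by `TFGraphBuilder`).
- `outputLogsPath`: The folder where the training logs are written.


Each parameter is set with the corresponding setter method in camel case, for example `.setInputCols()`.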

### `inputCols` and `outputCol`

These parameters define the column containing the `FEATURE_VECTOR` annotations that `GenericLogRegClassifierApproach` takes as input (built in this notebook from `SENTENCE_EMBEDDINGS` by `FeaturesAssembler`) and the name of the new column containing the predicted categories.

Let's define a pipeline to process raw texts into `FeatureVector` annotations:

### Data Preprocessing

In [7]:
# Download the sample datasets
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/ADE_Corpus_V2/ADE-NEG.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/ADE_Corpus_V2/DRUG-AE.rel

**ADE Negative Dataset**

In [8]:
df_neg= pd.read_csv("ADE-NEG.txt", header=None, delimiter="\t", names=["col1"])
df_neg.head()

| | col1 |
|---|---|
| 0 | 6460590 NEG Clioquinol intoxication occurring ... |
| 1 | 8600337 NEG "Retinoic acid syndrome" was preve... |
| 2 | 8402502 NEG BACKGROUND: External beam radiatio... |
| 3 | 8700794 NEG Although the enuresis ceased, she ... |
| 4 | 17662448 NEG A 42-year-old woman had uneventfu... |


In [9]:
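# Split only at the first "NEG" marker so that sentences which themselves contain "NEG" stay intact
df_neg['text'] = df_neg.col1.str.split('NEG', n=1).str[1]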
df_neg["category"] = "neg"
df_neg= df_neg[["text", "category"]]
df_neg.head()

| | text | category |
|---|---|---|
| 0 | Clioquinol intoxication occurring in the trea... | neg |
| 1 | "Retinoic acid syndrome" was prevented with s... | neg |
| 2 | BACKGROUND: External beam radiation therapy o... | neg |
| 3 | Although the enuresis ceased, she developed t... | neg |
| 4 | A 42-year-old woman had uneventful bilateral ... | neg |
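
The extraction step above relies on each line having the shape `<id> NEG <sentence>`, as the preview suggests. The sketch below shows, on one made-up line, why it matters to split only at the first marker; `parse_neg_line` is a hypothetical helper, not part of the notebook.

```python
def parse_neg_line(line: str) -> str:
    """Return the sentence that follows the first ' NEG ' marker."""
    _doc_id, marker, text = line.partition(" NEG ")
    if not marker:
        raise ValueError(f"no NEG marker found in: {line!r}")
    return text


# Hypothetical line; the second "NEG" belongs to the sentence itself.
line = "1234567 NEG Cultures were NEG for bacterial growth."

full_sentence = parse_neg_line(line)
naive_sentence = line.split("NEG")[1]  # splits at every "NEG" and keeps one piece

print(repr(full_sentence))   # 'Cultures were NEG for bacterial growth.'
print(repr(naive_sentence))  # ' Cultures were '
```

Splitting at every occurrence and keeping a single piece would silently truncate such a sentence, whereas splitting once keeps it whole.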


**ADE Positive Dataset**

In [10]:
df_pos= pd.read_csv("DRUG-AE.rel", header=None, delimiter="|")
df_pos.head()

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| 0 | 10030778 | Intravenous azithromycin-induced ototoxicity. | ototoxicity | 43 | 54 | azithromycin | 22 | 34 |
| 1 | 10048291 | Immobilization, while Paget's bone disease was... | increased calcium-release | 960 | 985 | dihydrotachysterol | 908 | 926 |
| 2 | 10048291 | Unaccountable severe hypercalcemia in a patien... | hypercalcemia | 31 | 44 | dihydrotachysterol | 94 | 112 |
| 3 | 10082597 | METHODS: We report two cases of pseudoporphyri... | pseudoporphyria | 620 | 635 | naproxen | 646 | 654 |
| 4 | 10082597 | METHODS: We report two cases of pseudoporphyri... | pseudoporphyria | 620 | 635 | oxaprozin | 659 | 668 |


In [11]:
df_pos["category"]= "pos"
df_pos.rename(columns={1: "text"}, inplace=True)
df_pos= df_pos[["text", "category"]]
df_pos.head()

| | text | category |
|---|---|---|
| 0 | Intravenous azithromycin-induced ototoxicity. | pos |
| 1 | Immobilization, while Paget's bone disease was... | pos |
| 2 | Unaccountable severe hypercalcemia in a patien... | pos |
| 3 | METHODS: We report two cases of pseudoporphyri... | pos |
| 4 | METHODS: We report two cases of pseudoporphyri... | pos |
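
Rows 3 and 4 of the preview appear to share the same sentence, since one sentence can mention several drugs. Once only the `text` column is kept, such relation records become repeated sentences. The sketch below illustrates the effect on made-up records; the field names in the comment are assumptions (the file has no header row), and `unique_sentences` is a hypothetical helper.

```python
# Hypothetical records in the pipe-delimited layout previewed above:
# id | sentence | effect | effect_start | effect_end | drug | drug_start | drug_end
raw_lines = [
    "111|Drug A and drug B both caused rash.|rash|30|34|Drug A|0|6",
    "111|Drug A and drug B both caused rash.|rash|30|34|drug B|11|17",
    "222|Nausea followed drug C.|Nausea|0|6|drug C|16|22",
]


def unique_sentences(lines):
    """Collect each sentence once, preserving first-seen order."""
    sentences = (line.split("|")[1] for line in lines)
    return list(dict.fromkeys(sentences))


sentences = unique_sentences(raw_lines)
print(len(raw_lines), "records ->", len(sentences), "unique sentences")
```

In the pandas flow used here, the analogous step would be `df_pos.drop_duplicates(subset="text")`. Whether to deduplicate is a modelling choice; this notebook keeps the records as they are.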


**Merging Positive and Negative dataset**

In [12]:
ade_df= pd.concat([df_neg, df_pos])
ade_df.head()

| | text | category |
|---|---|---|
| 0 | Clioquinol intoxication occurring in the trea... | neg |
| 1 | "Retinoic acid syndrome" was prevented with s... | neg |
| 2 | BACKGROUND: External beam radiation therapy o... | neg |
| 3 | Although the enuresis ceased, she developed t... | neg |
| 4 | A 42-year-old woman had uneventful bilateral ... | neg |


In [13]:
ade_df["category"].value_counts()

category
neg    16695
pos     6821
Name: count, dtype: int64
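
The counts above show a clear class imbalance. As a quick worked calculation (using only the two counts printed above), the class shares and the accuracy of a trivial "always predict the majority class" baseline can be derived as follows. This baseline is a useful reference point when reading the accuracy reported later.

```python
# Counts copied from the value_counts() output above.
counts = {"neg": 16695, "pos": 6821}

total = sum(counts.values())
shares = {label: n / total for label, n in counts.items()}

# A classifier that always answers with the majority label reaches this accuracy.
majority_label = max(counts, key=counts.get)
baseline_accuracy = shares[majority_label]

imbalance_ratio = counts["neg"] / counts["pos"]

print(f"total rows: {total}")
print(f"shares: neg={shares['neg']:.3f}, pos={shares['pos']:.3f}")
print(f"majority baseline accuracy: {baseline_accuracy:.3f}")
print(f"neg:pos ratio: {imbalance_ratio:.2f}")
```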

In [14]:
ade_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 23516 entries, 0 to 6820
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   text      23516 non-null  object
 1   category  23516 non-null  object
dtypes: object(2)
memory usage: 551.2+ KB


We sample 30% of the data so that the notebook runs faster. Training on the full dataset will likely give better scores.

In [15]:
spark_df = spark.createDataFrame(ade_df).sample(fraction=0.3, seed=3)  # keep roughly 30% of the rows

trainingData, testData = spark_df.randomSplit([0.8, 0.2], seed = 100)

print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 5617
Test Dataset Count: 1390
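
Note that the split is not exactly 80/20: 5617 of 7007 rows (about 80.2%) ended up in the training set. A random split that decides row by row only reaches the requested proportions approximately. The plain-Python sketch below illustrates that behaviour; it is not Spark's implementation, and `bernoulli_split` is a hypothetical helper.

```python
import random


def bernoulli_split(rows, train_fraction, seed):
    """Assign each row to train or test independently with a seeded RNG."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        if rng.random() < train_fraction:
            train.append(row)
        else:
            test.append(row)
    return train, test


rows = list(range(7007))  # same size as the sampled dataset above
train, test = bernoulli_split(rows, train_fraction=0.8, seed=100)

print(len(train), len(test), round(len(train) / len(rows), 4))
```

Fixing the seed makes the split reproducible, which is why both `sample` and `randomSplit` receive an explicit seed in the cell above.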


In [16]:
spark_df.groupBy("category").count().show()

+--------+-----+
|category|count|
+--------+-----+
|     neg| 4939|
|     pos| 2068|
+--------+-----+



In [17]:
spark_df.printSchema()

root
 |-- text: string (nullable = true)
 |-- category: string (nullable = true)



In [18]:
spark_df.head(3)

[Row(text=' Clioquinol intoxication occurring in the treatment of acrodermatitis enteropathica with reference to SMON outside of Japan.', category='neg'),
 Row(text=' A 42-year-old woman had uneventful bilateral laser-assisted subepithelial keratectomy (LASEK) to correct myopia.', category='neg'),
 Row(text=' A 16-year-old girl with erosive, polyarticular JRA showed no detectable change in her articular disease following nine exchanges.', category='neg')]

### 100 Dimension Healthcare Embeddings (embeddings_healthcare_100d)



Now we will extract [healthcare_100d embeddings](https://nlp.johnsnowlabs.com/2020/05/29/embeddings_healthcare_100d_en.html) and use them to train the classification model.

In [19]:
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

tokenizer = nlp.Tokenizer() \
    .setInputCols(["document"]) \
    .setOutputCol("token")

word_embeddings = nlp.WordEmbeddingsModel.pretrained("embeddings_healthcare_100d","en","clinical/models")\
    .setInputCols(["document","token"])\
    .setOutputCol("word_embeddings")

sentence_embeddings = nlp.SentenceEmbeddings() \
    .setInputCols(["document", "word_embeddings"]) \
    .setOutputCol("sentence_embeddings") \
    .setPoolingStrategy("AVERAGE")

embeddings_pipeline = nlp.Pipeline(
    stages = [
        document_assembler,
        tokenizer,
        word_embeddings,
        sentence_embeddings,

    ])

embeddings_healthcare_100d download started this may take some time.
Approximate size to download 475.8 MB
[OK!]
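
The `AVERAGE` pooling strategy set on `SentenceEmbeddings` turns the per-token vectors into a single vector per document. Conceptually this is an element-wise mean, sketched below in plain Python with made-up 3-dimensional token vectors (the real embeddings have 100 dimensions); `average_pool` is a hypothetical helper, not the library's code.

```python
def average_pool(token_vectors):
    """Element-wise mean of equally sized token vectors."""
    if not token_vectors:
        raise ValueError("need at least one token vector")
    dim = len(token_vectors[0])
    n_tokens = len(token_vectors)
    return [sum(vec[i] for vec in token_vectors) / n_tokens for i in range(dim)]


# Hypothetical embeddings for a three-token sentence.
token_vectors = [
    [1.0, 2.0, 3.0],
    [3.0, 4.0, 5.0],
    [5.0, 0.0, 1.0],
]

sentence_vector = average_pool(token_vectors)
print(sentence_vector)  # [3.0, 2.0, 3.0]
```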


In [20]:
trainingData_with_embeddings = embeddings_pipeline.fit(trainingData).transform(trainingData)
trainingData_with_embeddings = trainingData_with_embeddings.select("text","category","sentence_embeddings")
trainingData_with_embeddings.show(2,truncate=60)

+------------------------------------------------------------+--------+------------------------------------------------------------+
|                                                        text|category|                                         sentence_embeddings|
+------------------------------------------------------------+--------+------------------------------------------------------------+
| "Syndrome malin"-like symptoms probably due to interacti...|     neg|[{sentence_embeddings, 0, 109,  "Syndrome malin"-like sym...|
| 'Bail-out' bivalirudin use in patients with thrombotic c...|     neg|[{sentence_embeddings, 0, 150,  'Bail-out' bivalirudin us...|
+------------------------------------------------------------+--------+------------------------------------------------------------+
only showing top 2 rows



In [21]:
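# Fit on the training split rather than the test split; the stages here are pretrained or rule-based, so the resulting embeddings are the same
testData_with_embeddings = embeddings_pipeline.fit(trainingData).transform(testData)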
testData_with_embeddings = testData_with_embeddings.select("text","category","sentence_embeddings")
testData_with_embeddings.show(2,truncate=60)

+------------------------------------------------------------+--------+------------------------------------------------------------+
|                                                        text|category|                                         sentence_embeddings|
+------------------------------------------------------------+--------+------------------------------------------------------------+
| (2) Rehabilitation of a 29-year-old man with a 7-year hi...|     neg|[{sentence_embeddings, 0, 176,  (2) Rehabilitation of a 2...|
| 2-Chlordeoxyadenosine (2-CdA) is an antineoplastic/immun...|     neg|[{sentence_embeddings, 0, 170,  2-Chlordeoxyadenosine (2-...|
+------------------------------------------------------------+--------+------------------------------------------------------------+
only showing top 2 rows



In [22]:
testData_with_embeddings.printSchema()

root
 |-- text: string (nullable = true)
 |-- category: string (nullable = true)
 |-- sentence_embeddings: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)



In [23]:
testData_with_embeddings.select(testData_with_embeddings.sentence_embeddings.embeddings).show(3,truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                          sentence_embeddings.embeddings|
+------------------------------------------------------------------------------------------------------------------------+
|[[-0.07506646, 0.053498756, 0.20358036, 0.22310139, -0.12403396, -0.07148757, 0.07359774, -0.07931808, -0.058168065, ...|
|[[-0.010381512, 0.082957536, 0.10597669, 0.22413144, -0.17930073, -0.038972173, -0.017218085, -0.08750686, -0.0118597...|
|[[0.061766334, -0.050023016, 0.24404901, 0.18501845, -0.14053237, -0.0817295, -0.06483702, -0.1373563, 0.0844157, 0.0...|
+------------------------------------------------------------------------------------------------------------------------+
only showing top 3 rows



In [24]:
log_folder="ADE_logs_healthcare_100d"
!mkdir -p $log_folder

### GenericLogRegClassifier

In [25]:
!pip install -q tensorflow==2.11.0
!pip install -q tensorflow-addons


In [26]:
#from sparknlp_jsl.annotator import TFGraphBuilder

graph_folder = "gc_graph"

gc_logreg_graph_builder = medical.TFGraphBuilder()\
    .setModelName("logreg_classifier")\
    .setInputCols(["feature_vector"]) \
    .setLabelColumn("category")\
    .setGraphFolder(graph_folder)\
    .setGraphFile("log_reg_graph.pb")

`GenericLogRegClassifierApproach` consumes the output of `FeaturesAssembler`, which collects features from one or more columns (here, a single embeddings column) into `FEATURE_VECTOR` annotations.

The classifier takes these `FEATURE_VECTOR` annotations as input, classifies them, and outputs `CATEGORY` annotations.

In [27]:
features_asm = medical.FeaturesAssembler()\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("feature_vector")

gen_clf = medical.GenericLogRegClassifierApproach()\
    .setLabelColumn("category")\
    .setInputCols("feature_vector")\
    .setOutputCol("prediction")\
    .setModelFile(f"{graph_folder}/log_reg_graph.pb")\
    .setEpochsNumber(20)\
    .setBatchSize(128)\
    .setLearningRate(0.01)\
    .setOutputLogsPath(log_folder)\
    .setDropout(0.1)\
    .setFixImbalance(True)

# Optional settings that can be chained onto the classifier above:
# .setValidationSplit(0.1)
# .setFeatureScaling("zscore")  # possible values: 'zscore', 'minmax', or empty (no scaling)


clf_Pipeline = nlp.Pipeline(stages=[
    features_asm,
    gc_logreg_graph_builder,
    gen_clf])


In [28]:
gen_clf.extractParamMap()

{Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='lazyAnnotator', doc='Whether this AnnotatorModel acts as lazy in RecursivePipelines'): False,
 Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='labelColumn', doc='Column with one label per document'): 'category',
 Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='batchSize', doc='Size for each batch in the optimization process'): 128,
 Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='epochsN', doc='Number of epochs for the optimization process'): 20,
 Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='learningRate', doc='Learning rate for the optimization process'): 0.01,
 Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='dropout', doc='Dropout at the output of each layer'): 0.1,
 Param(parent='GenericLogRegClassifierApproach_de36ab6999b0', name='fixImbalance', doc='A flag indicating whenther to balance the trainig set'): True,
 Param

In [29]:
generic_model_hc100 = clf_Pipeline.fit(trainingData_with_embeddings)

TF Graph Builder configuration:
Model name: logreg_classifier
Graph folder: gc_graph
Graph file name: log_reg_graph.pb
Build params: {'input_dim': 100, 'output_dim': 2, 'hidden_layers': [], 'output_act': 'sigmoid'}
logreg_classifier graph exported to gc_graph/log_reg_graph.pb


In [30]:
!cat $log_folder/GenericLogRegClassifierApproach*

Training 20 epochs
Epoch 1/20	0.36s	Loss: 26.554672	ACC: 0.6926758
Epoch 2/20	0.08s	Loss: 24.614397	ACC: 0.7155933
Epoch 3/20	0.07s	Loss: 23.559282	ACC: 0.7293469
Epoch 4/20	0.07s	Loss: 23.000538	ACC: 0.7385924
Epoch 5/20	0.07s	Loss: 22.46028	ACC: 0.7422268
Epoch 6/20	0.09s	Loss: 22.312366	ACC: 0.7477782
Epoch 7/20	0.07s	Loss: 22.094908	ACC: 0.7489378
Epoch 8/20	0.07s	Loss: 21.911024	ACC: 0.7502404
Epoch 9/20	0.09s	Loss: 21.764406	ACC: 0.75319916
Epoch 10/20	0.08s	Loss: 21.474533	ACC: 0.75598043
Epoch 11/20	0.06s	Loss: 21.450632	ACC: 0.7571054
Epoch 12/20	0.06s	Loss: 21.443283	ACC: 0.7607634
Epoch 13/20	0.06s	Loss: 21.435198	ACC: 0.75555456
Epoch 14/20	0.06s	Loss: 21.29784	ACC: 0.7556379
Epoch 15/20	0.05s	Loss: 21.273758	ACC: 0.7620644
Epoch 16/20	0.06s	Loss: 21.19557	ACC: 0.7645266
Epoch 17/20	0.05s	Loss: 21.195524	ACC: 0.7652966
Epoch 18/20	0.05s	Loss: 21.248274	ACC: 0.76114213
Epoch 19/20	0.06s	Loss: 21.103481	ACC: 0.7673801
Epoch 20/20	0.05s	Loss: 21.037926	ACC: 0.76636195
Training
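
Because the log is plain text with one line per epoch, it is easy to turn into structured data, for example to plot the loss curve. The sketch below parses lines in the format shown above with a regular expression; the helper name `parse_training_log` and the dictionary keys are our own choices.

```python
import re

# Matches lines such as "Epoch 1/20<TAB>0.36s<TAB>Loss: 26.554672<TAB>ACC: 0.6926758".
EPOCH_RE = re.compile(
    r"Epoch (\d+)/(\d+)\s+([\d.]+)s\s+Loss: ([\d.]+)\s+ACC: ([\d.]+)"
)


def parse_training_log(text):
    """Return one dict per epoch line found in the log text."""
    history = []
    for match in EPOCH_RE.finditer(text):
        epoch, _total, seconds, loss, acc = match.groups()
        history.append(
            {
                "epoch": int(epoch),
                "seconds": float(seconds),
                "loss": float(loss),
                "acc": float(acc),
            }
        )
    return history


# First lines of the log shown above.
sample_log = (
    "Training 20 epochs\n"
    "Epoch 1/20\t0.36s\tLoss: 26.554672\tACC: 0.6926758\n"
    "Epoch 2/20\t0.08s\tLoss: 24.614397\tACC: 0.7155933\n"
)

history = parse_training_log(sample_log)
best = max(history, key=lambda row: row["acc"])
print(history)
print("best epoch so far:", best["epoch"])
```

To parse the real file, read its contents from the log folder and pass the text to `parse_training_log`.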

In [31]:
preds = generic_model_hc100.transform(testData_with_embeddings)

In [32]:
preds.printSchema()
preds.select(preds.prediction).show(5, truncate=False)
preds.select(preds.category, preds.prediction.result).show(5, truncate=False)

root
 |-- text: string (nullable = true)
 |-- category: string (nullable = true)
 |-- sentence_embeddings: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- feature_vector: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |

In [33]:
preds_df = preds.select('category','prediction.result').toPandas()
preds_df['result'] = preds_df.result.apply(lambda x : x[0])

print(classification_report(preds_df['category'], preds_df['result']))

              precision    recall  f1-score   support

         neg       0.77      0.92      0.84       972
         pos       0.67      0.38      0.48       418

    accuracy                           0.76      1390
   macro avg       0.72      0.65      0.66      1390
weighted avg       0.74      0.76      0.73      1390
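
The summary rows of the report can be reproduced from the per-class rows, which helps when interpreting them. The sketch below recomputes the F1 scores, the macro and weighted averages, and the accuracy from the rounded precision, recall, and support values printed above. Because the inputs are already rounded to two decimals, the results only match the report after rounding.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Per-class values copied from the report above (rounded to two decimals).
report = {
    "neg": {"precision": 0.77, "recall": 0.92, "support": 972},
    "pos": {"precision": 0.67, "recall": 0.38, "support": 418},
}

f1_scores = {label: f1(row["precision"], row["recall"]) for label, row in report.items()}
total_support = sum(row["support"] for row in report.values())

# Macro average: every class counts equally.
macro_f1 = sum(f1_scores.values()) / len(f1_scores)

# Weighted average: every class counts in proportion to its support.
weighted_f1 = sum(f1_scores[label] * row["support"] for label, row in report.items()) / total_support

# Accuracy equals the support-weighted average of the per-class recalls.
accuracy = sum(row["recall"] * row["support"] for row in report.values()) / total_support

print({label: round(score, 2) for label, score in f1_scores.items()})
print(round(macro_f1, 2), round(weighted_f1, 2), round(accuracy, 2))
```

The low recall for `pos` (0.38) is what pulls the macro average well below the accuracy. Note also that an accuracy of 0.76 is only modestly above the share of the `neg` class in the test set (972 of 1390, about 0.70).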

