### ClassifierDL for Multi-class Text Classification

[Source: Spark NLP workshop notebook — ClassifierDL multi-class news category classifier](https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/training/english/classification/ClassifierDL_Train_multi_class_news_category_classifier.ipynb)

In [1]:
from os import environ

java8_path: str = "C:\Java\jdk1.8.0_191"

environ["JAVA_HOME"] =  java8_path
environ["PATH"] = environ["JAVA_HOME"] + "/bin;" + environ["PATH"]

! java -version

java version "1.8.0_191"
Java(TM) SE Runtime Environment (build 1.8.0_191-b12)
Java HotSpot(TM) 64-Bit Server VM (build 25.191-b12, mixed mode)


In [2]:
# --- Standard Library --- #
import concurrent.futures as ccf
import re
import urllib.request as ureq
from functools import partial
from pathlib import Path


#  --- Third-party libraries --- #
import sparknlp
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession, dataframe

# https://github.com/JohnSnowLabs/spark-nlp/blob/master/python/sparknlp/__init__.py
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import (
    ClassifierDLApproach,
    UniversalSentenceEncoder,
)
# from sparknlp.transformers import *
# from sparknlp.common import *


Spark NLP version:  3.1.1
Apache Spark version:  3.1.2


<h1 style="text-align: center;">Create Spark Session</h1>
<h4 style="text-align: center;color: white;">Set attributes, as appropriate.</h4>

In [None]:

# Build (or reuse) a local Spark session with the Spark NLP package attached.
sparky = (
    SparkSession.builder
    .appName("Spark NLP")
    .master("local[4]")  # run locally with 4 worker threads
    # Driver heap size. The value needs a unit suffix: the previous "166" was
    # read as 166 MiB, which is below the minimum heap Spark accepts at start-up.
    .config("spark.driver.memory", "16G")
    .config("spark.driver.maxResultSize", "0")  # 0 = no limit on collected results
    .config("spark.kryoserializer.buffer.max", "2000M")
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:3.1.1")
    .getOrCreate()
)

# Report the library versions in use so the run is reproducible.
print("Spark NLP version: ", sparknlp.version())
print("Apache Spark version: ", sparky.version)

In [None]:
# Training and testing news-category datasets: where each file is fetched from
# and where it is stored locally. Forward slashes are used for the local paths
# because they are valid on both Windows and POSIX systems.
urls_and_names = [
    {
        "name": "train",
        "url": "https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv",
        "localpath": "data/news_category_train.csv",
    },
    {
        "name": "test",
        "url": "https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_test.csv",
        "localpath": "data/news_category_test.csv",
    },
]


def retriever(url, filename):
    """Download ``url`` into the local file ``filename``.

    The parent directory of ``filename`` is created first if it is missing.
    Returns whatever ``urllib.request.urlretrieve`` returns.

    Note: this used to be ``partial(urlretrieve, url=None, filename=None)``.
    Calling that partial with positional arguments raised ``TypeError``
    (multiple values for ``url``), so nothing was ever downloaded.
    """
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    return ureq.urlretrieve(url=url, filename=filename)


# Download the files concurrently.
with ccf.ThreadPoolExecutor(max_workers = 4) as executor:
    downloads = [
        executor.submit(retriever, i["url"], i["localpath"]) for i in urls_and_names
    ]

# A future stores any exception raised in its worker thread; calling result()
# re-raises it here so that a failed download is not silently ignored.
for download in downloads:
    download.result()

# Alternative on Linux
# !wget -O data/news_category_train.csv https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv
# !wget -O data/news_category_test.csv https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_test.csv

In [None]:
# --- Create IPython magic function for viewing data sample --- #

# View sample of data
from sys import platform

if "win" in platform:
    from IPython.core.magic import register_line_magic

    @register_line_magic
    def header(line):
        """IPython magic method to read first N lines of a file if
        using Windows platform.
        
        Example Usage
        -------------
        %header data\news_category_train.csv
        %header data\news_category_train.csv 5
        """
        lines = line.split(" ")
        filepath = lines[0]
        
        # Set number of lines variable default and update
        # if valid data present.
        n_lines: int = 10
        if len(lines) == 2:
            n_lines = int(lines[1])

        output: list = []
        with open(filepath) as f:
            # first N lines
            for _ in range(n_lines):
                output.append(f.readline())
        print("".join(output))

    %header data\news_category_train.csv 5
    del header

else:
    !head data\news_category_train.csv

## Create a Spark DataFrame from the training data

In [None]:
# Pick the "train" entry from the download list and keep its local file path.
train_entries = [entry for entry in urls_and_names if entry["name"] == "train"]
train_path = train_entries[0]["localpath"]

# Load the CSV into a Spark DataFrame, treating the first row as column names.
trainDF = sparky.read.option("header", True).csv(train_path)

In [None]:
# Confirm the load succeeded: report the row count, then the schema and a sample.
row_count = trainDF.count()
print(f"Spark dataframe row count: {row_count:,}\n")

print("DataFrame metadata:\n")
trainDF.printSchema()
trainDF.show()

In [None]:
# -- Use SQL to peruse data

# This is a wrapper method to help view SQL data in Spark with simple queries.
# Verbose-mode regular expression whose single capture group is the identifier
# (letters, digits, underscores) that follows the FROM keyword of a query.
SQL_PATTERN: str = r"""
    (?:\n|\s)?
    FROM\s([a-z0-9_]+)
    (?:\n|\s)?
    """

# Compiled once and reused; matching is case-insensitive, and literal
# whitespace inside the pattern text is ignored (verbose mode).
p = re.compile(SQL_PATTERN, re.IGNORECASE | re.VERBOSE)

def query_view_df(func):
    """Decorator: run a SQL query against a Spark DataFrame through a temporary view.

    The decorated callable is invoked with a query string and a
    ``pyspark.sql.dataframe.DataFrame`` (in either order). The wrapper then:

    1. passes ``(query, dataframe)`` to ``func`` and uses the ``(query, dataframe)``
       pair it returns (so ``func`` may normalise the query);
    2. extracts the view name that follows ``FROM`` in the query;
    3. registers the DataFrame under that name as a temporary view, shows the
       query result, and drops the view again.

    Raises
    ------
    TypeError
        If the call does not supply exactly two arguments, or if neither of
        them is a string.
    AssertionError
        If the non-string argument is not a PySpark DataFrame.
    ValueError
        If no ``FROM <name>`` clause can be found in the query.
    """
    from functools import wraps

    @wraps(func)
    def inner(*args, **kwargs):
        # Accept the two values positionally or by keyword.
        values = [*args, *kwargs.values()]
        if len(values) != 2:
            raise TypeError(
                f"Expected a query string and a DataFrame, got {len(values)} argument(s)."
            )

        # The arguments may come in either order; the string is the query.
        if isinstance(values[0], str):
            _query, _df = values
        else:
            _df, _query = values

        if not isinstance(_query, str):
            raise TypeError("One of the two arguments must be the SQL query string.")

        # Sanity check.
        assert isinstance(_df, dataframe.DataFrame), "PySpark.sql.dataframe.DataFrame expected!"

        # Let the wrapped function process its inputs (previously it was never called).
        _query, _df = func(_query, _df)

        # Search for the view name within the SQL query.
        res = p.search(_query)
        if res is None:
            raise ValueError("Could not find a 'FROM <view name>' clause in the query.")
        _viewname = res.group(1)

        # Create temporary view, run SQL, and always destroy the temporary view,
        # even if the query fails. The session object in this notebook is `sparky`.
        _df.createOrReplaceTempView(_viewname)
        try:
            sparky.sql(_query).show()
        finally:
            sparky.catalog.dropTempView(_viewname)

    return inner


@query_view_df
def run_query_df(query_string, dataframe) -> None:
    """Pair the whitespace-trimmed query string with the given DataFrame.

    The function is wrapped by the ``query_view_df`` decorator.
    """
    trimmed_query = query_string.strip()
    return trimmed_query, dataframe
    
# trainDF.createOrReplaceTempView("dataview")
# sparky.sql("SELECT category FROM dataview GROUP BY category").show()

In [None]:
# Sanity check: trainDF must be a PySpark DataFrame
# (`dataframe` is the pyspark.sql.dataframe module imported at the top).
assert isinstance(trainDF, dataframe.DataFrame), "Error!"

In [None]:
# List the distinct categories. "dataview" is not registered beforehand: the
# wrapper takes the name after FROM and creates a temporary view with it.
run_query_df("SELECT category FROM dataview GROUP BY category", trainDF)

In [None]:
from pyspark.ml import Pipeline

# https://github.com/JohnSnowLabs/spark-nlp/blob/master/python/sparknlp/__init__.py
from sparknlp.annotator import ClassifierDLApproach, UniversalSentenceEncoder
# from sparknlp.transformers import *
# from sparknlp.common import *
from sparknlp.base import DocumentAssembler

In [None]:
# First pipeline stage: reads the "description" column and writes its
# annotations to a new "document" column.
doc_ = DocumentAssembler()
doc_.setInputCol("description")
doc_.setOutputCol("document")

# use_ = (UniversalSentenceEncoder.pretrained()
#         .setInputCols("document")
#         .setOutputCol("sentence_embeddings")
#        )

# clsdl_ = (ClassifierDLApproach().setInputCols("sentence_embeddings")
#           .setOutputCol("class")
#           .setLabelColumn("category")
#           .setMaxEpochs(5)
#           .setEnableOutputLogs(True)
#          )

# pipeline = Pipeline(
#     stages = [
#         doc_,
#         use_,
#         clsdl_
#     ])