126 changes: 25 additions & 101 deletions python/example/crf-ner/ner.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"collapsed": true
},
@@ -16,14 +16,17 @@
"\n",
"from sparknlp.annotator import *\n",
"from sparknlp.common import *\n",
"from sparknlp.base import *"
"from sparknlp.base import *\n",
"\n",
"import time\n",
"import zipfile"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": true
},
"outputs": [],
"source": [
@@ -39,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"collapsed": true
},
@@ -61,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"collapsed": true
},
@@ -87,13 +90,12 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": true
},
"outputs": [],
"source": [
"import time\n",
"\n",
"documentAssembler = DocumentAssembler()\\\n",
" .setInputCol(\"text\")\\\n",
@@ -113,6 +115,7 @@
" .setInputCols([\"token\", \"document\"])\\\n",
" .setOutputCol(\"pos\")\n",
"\n",
"#.setEmbeddingsSource(\"glove.6B.100d.txt\", 100, 2)\\\n",
"nerTagger = NerCrfApproach()\\\n",
" .setInputCols([\"sentence\", \"token\", \"pos\"])\\\n",
" .setLabelColumn(\"label\")\\\n",
@@ -121,7 +124,6 @@
" .setMaxEpochs(20)\\\n",
" .setLossEps(1e-3)\\\n",
" .setDicts([\"ner-corpus/dict.txt\"])\\\n",
" .setEmbeddingsSource(\"glove.6B.100d.txt\", 100, 2)\\\n",
" .setDatasetPath(\"eng.train\")\\\n",
" .setL2(1)\\\n",
" .setC0(1250000)\\\n",
@@ -145,44 +147,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+---------+--------------------+\n",
"|itemid|sentiment| text|\n",
"+------+---------+--------------------+\n",
"| 1| 0| ...|\n",
"| 2| 0| ...|\n",
"| 3| 1| omg...|\n",
"| 4| 0| .. Omga...|\n",
"| 5| 0| i think ...|\n",
"| 6| 0| or i jus...|\n",
"| 7| 1| Juuuuuuuuu...|\n",
"| 8| 0| Sunny Agai...|\n",
"| 9| 1| handed in m...|\n",
"| 10| 1| hmmmm.... i...|\n",
"| 11| 0| I must thin...|\n",
"| 12| 1| thanks to a...|\n",
"| 13| 0| this weeken...|\n",
"| 14| 0| jb isnt show...|\n",
"| 15| 0| ok thats it ...|\n",
"| 16| 0| <-------- ...|\n",
"| 17| 0| awhhe man.......|\n",
"| 18| 1| Feeling stran...|\n",
"| 19| 0| HUGE roll of ...|\n",
"| 20| 0| I just cut my...|\n",
"+------+---------+--------------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Load the input data to be annotated\n",
"data = spark. \\\n",
@@ -196,67 +163,24 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start fitting\n",
"Fitting is ended\n"
]
}
],
"outputs": [],
"source": [
"start = time.time()\n",
"print(\"Start fitting\")\n",
"model = pipeline.fit(data)\n",
"print(\"Fitting is ended\")"
"print(\"Fitting is ended\")\n",
"print (time.time() - start)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+---------+--------------------+--------------------+\n",
"|itemid|sentiment| text| finished_ner|\n",
"+------+---------+--------------------+--------------------+\n",
"| 1| 0| ...|word->is#result->...|\n",
"| 2| 0| ...|word->I#result->O...|\n",
"| 3| 1| omg...|word->omg#result-...|\n",
"| 4| 0| .. Omga...|word->Omgaga.#res...|\n",
"| 5| 0| i think ...|word->i#result->O...|\n",
"| 6| 0| or i jus...|word->or#result->...|\n",
"| 7| 1| Juuuuuuuuu...|word->Juuuuuuuuuu...|\n",
"| 8| 0| Sunny Agai...|word->Sunny#resul...|\n",
"| 9| 1| handed in m...|word->handed#resu...|\n",
"| 10| 1| hmmmm.... i...|word->i#result->O...|\n",
"| 11| 0| I must thin...|word->I#result->O...|\n",
"| 12| 1| thanks to a...|word->thanks#resu...|\n",
"| 13| 0| this weeken...|word->this#result...|\n",
"| 14| 0| jb isnt show...|word->jb#result->...|\n",
"| 15| 0| ok thats it ...|word->ok#result->...|\n",
"| 16| 0| <-------- ...|word->This#result...|\n",
"| 17| 0| awhhe man.......|word->awhhe#resul...|\n",
"| 18| 1| Feeling stran...|word->Feeling#res...|\n",
"| 19| 0| HUGE roll of ...|word->HUGE#result...|\n",
"| 20| 0| I just cut my...|word->I#result->O...|\n",
"+------+---------+--------------------+--------------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ner_data = model.transform(data)\n",
"ner_data.show()"
@@ -266,7 +190,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": true
},
"outputs": [],
"source": [
@@ -293,7 +217,7 @@
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
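Taken together, the notebook edits move the imports into the first cell, comment out the GloVe embeddings source on NerCrfApproach, clear stale execution counts and outputs, and wrap the fit in a timer. A minimal sketch of the resulting training flow follows; the sentence-detector, tokenizer, Pipeline, and data-loading cells are elided in the hunks above, so those class names, column names, and the toy input are assumptions, while eng.train and ner-corpus/dict.txt are the paths the notebook itself uses. A running SparkSession named spark with the spark-nlp jar attached is assumed throughout.

import time

from pyspark.ml import Pipeline
from sparknlp.annotator import *
from sparknlp.base import *

documentAssembler = (DocumentAssembler()
                     .setInputCol("text")
                     .setOutputCol("document"))

# These two cells are not shown in the diff; class and column names assumed.
sentenceDetector = (SentenceDetectorModel()
                    .setInputCols(["document"])
                    .setOutputCol("sentence"))

tokenizer = (RegexTokenizer()
             .setInputCols(["document"])
             .setOutputCol("token"))

posTagger = (PerceptronApproach()
             .setInputCols(["token", "document"])
             .setOutputCol("pos"))

# Embeddings are no longer configured here; this PR comments out the
# setEmbeddingsSource call. Output column name "ner" is assumed.
nerTagger = (NerCrfApproach()
             .setInputCols(["sentence", "token", "pos"])
             .setLabelColumn("label")
             .setOutputCol("ner")
             .setMaxEpochs(20)
             .setLossEps(1e-3)
             .setDicts(["ner-corpus/dict.txt"])
             .setDatasetPath("eng.train")
             .setL2(1)
             .setC0(1250000))

pipeline = Pipeline(stages=[documentAssembler, sentenceDetector, tokenizer,
                            posTagger, nerTagger])

# Toy stand-in for the notebook's sentiment CSV; the CRF's labeled training
# examples come from eng.train, so no "label" column is needed here.
data = spark.createDataFrame([("John lives in New York",)], ["text"])

start = time.time()
print("Start fitting")
model = pipeline.fit(data)
print("Fitting has ended")
print(time.time() - start)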
5 changes: 3 additions & 2 deletions python/sparknlp/annotator.py
@@ -7,6 +7,7 @@
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaTransformer, JavaModel, JavaEstimator
from pyspark.ml.param.shared import Param, Params, TypeConverters
from sparknlp.base import JavaRecursiveEstimator

if sys.version_info[0] == 2:
#Needed. Delete once DA becomes an annotator in 1.1.x
@@ -130,6 +131,7 @@ def setPattern(self, value):
def setLowercase(self, value):
return self._set(lowercase=value)


class RegexMatcher(AnnotatorTransformer):

strategy = Param(Params._dummy(),
@@ -523,8 +525,7 @@ class NorvigSweetingModel(JavaModel, JavaMLWritable, JavaMLReadable, AnnotatorPr
name = "NorvigSweetingModel"



class NerCrfApproach(JavaEstimator, JavaMLWritable, JavaMLReadable, AnnotatorProperties, AnnotatorWithEmbeddings):
class NerCrfApproach(JavaRecursiveEstimator, JavaMLWritable, JavaMLReadable, AnnotatorProperties, AnnotatorWithEmbeddings):
labelColumn = Param(Params._dummy(),
"labelColumn",
"Column with label per each token",
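The base-class swap from JavaEstimator to JavaRecursiveEstimator is what exposes recursive fitting on NerCrfApproach: fit now accepts an optional pipeline argument carrying the already-fitted upstream stages, presumably so the approach can annotate its external training corpus (eng.train) with the very same document/sentence/token/pos stages. A hedged sketch; the variable names here are illustrative, not from the diff:

from pyspark.ml import PipelineModel

# fitted_stages: a list of already-fitted Transformers that produce the
# "sentence", "token" and "pos" columns NerCrfApproach consumes
# (hypothetical name, not from the diff).
partial_pipeline = PipelineModel(fitted_stages)

# With pipeline= given, this routes to recursiveFit on the JVM side;
# without it, to the ordinary single-argument fit.
ner_model = nerTagger.fit(training_df, pipeline=partial_pipeline)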
90 changes: 89 additions & 1 deletion python/sparknlp/base.py
@@ -1,7 +1,95 @@
from pyspark import keyword_only
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaTransformer
from pyspark.ml.wrapper import JavaTransformer, JavaEstimator
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.pipeline import Pipeline, PipelineModel, Estimator, Transformer


class JavaRecursiveEstimator(JavaEstimator):

def _fit_java(self, dataset, pipeline=None):
"""
Fits a Java model to the input dataset.
:param dataset: input dataset, which is an instance of
:py:class:`pyspark.sql.DataFrame`
        :param pipeline: optional PipelineModel of the stages fitted so far,
            forwarded to recursiveFit on the Java side
:return: fitted Java model
"""
self._transfer_params_to_java()
if pipeline:
return self._java_obj.recursiveFit(dataset._jdf, pipeline._to_java())
else:
return self._java_obj.fit(dataset._jdf)

def _fit(self, dataset, pipeline=None):
java_model = self._fit_java(dataset, pipeline)
model = self._create_model(java_model)
return self._copyValues(model)

def fit(self, dataset, params=None, pipeline=None):
"""
Fits a model to the input dataset with optional parameters.
:param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
:param params: an optional param map that overrides embedded params. If a list/tuple of
param maps is given, this calls fit on each param map and returns a list of
                       models.
        :param pipeline: optional PipelineModel of previously fitted stages,
            passed through to the underlying recursive fit
        :returns: fitted model(s)
"""
if params is None:
params = dict()
if isinstance(params, (list, tuple)):
models = [None] * len(params)
for index, model in self.fitMultiple(dataset, params):
models[index] = model
return models
elif isinstance(params, dict):
if params:
return self.copy(params)._fit(dataset, pipeline=pipeline)
else:
return self._fit(dataset, pipeline=pipeline)
else:
raise ValueError("Params must be either a param map or a list/tuple of param maps, "
"but got %s." % type(params))


class RecursivePipeline(Pipeline, JavaEstimator):
@keyword_only
def __init__(self, *args, **kwargs):
super(RecursivePipeline, self).__init__(*args, **kwargs)
self._java_obj = self._new_java_obj("com.johnsnowlabs.nlp.RecursivePipeline", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)

def _fit(self, dataset):
stages = self.getStages()
for stage in stages:
if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
raise TypeError(
"Cannot recognize a pipeline stage of type %s." % type(stage))
indexOfLastEstimator = -1
for i, stage in enumerate(stages):
if isinstance(stage, Estimator):
indexOfLastEstimator = i
transformers = []
        for i, stage in enumerate(stages):
            if i <= indexOfLastEstimator:
                if isinstance(stage, Transformer):
                    transformers.append(stage)
                    dataset = stage.transform(dataset)
                elif isinstance(stage, JavaRecursiveEstimator):
                    # recursive estimators also receive the stages fitted so far
                    model = stage.fit(dataset, pipeline=PipelineModel(transformers))
                    transformers.append(model)
                    if i < indexOfLastEstimator:
                        dataset = model.transform(dataset)
                else:
                    model = stage.fit(dataset)
                    transformers.append(model)
                    if i < indexOfLastEstimator:
                        dataset = model.transform(dataset)
            else:
                # past the last estimator: remaining transformers are carried
                # along without being fitted or applied
                transformers.append(stage)
        return PipelineModel(transformers)


class DocumentAssembler(JavaTransformer, JavaMLReadable, JavaMLWritable):
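RecursivePipeline is a drop-in replacement for pyspark.ml.Pipeline that does this bookkeeping automatically: each JavaRecursiveEstimator stage is fitted with a PipelineModel of whatever has been fitted so far, while plain transformers and estimators behave exactly as in Pipeline._fit. A usage sketch reusing the notebook's stage names:

from sparknlp.base import RecursivePipeline

pipeline = RecursivePipeline(stages=[
    documentAssembler,  # Transformer: applied to the data, then carried along
    sentenceDetector,
    tokenizer,
    posTagger,          # ordinary Estimator: fitted as usual
    nerTagger])         # recursive estimator: fitted with
                        # pipeline=PipelineModel(<stages fitted so far>)

model = pipeline.fit(data)

Handing each recursive estimator the partially-fitted pipeline, rather than making it refit its upstream stages, presumably keeps training cheap and guarantees it annotates with exactly the stages the final PipelineModel will contain.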
8 changes: 6 additions & 2 deletions src/main/resources/log4j.properties
@@ -1,7 +1,11 @@
log4j.rootLogger=ERROR, STDOUT
log4j.rootLogger=WARN, STDOUT
log4j.appender.STDOUT=org.apache.log4j.ConsoleAppender
log4j.appender.STDOUT.layout=org.apache.log4j.PatternLayout
log4j.appender.STDOUT.layout.ConversionPattern=[%5p] %m%n

log4j.logger.AnnotatorLogger=WARN
log4j.logger.CRF=INFO
log4j.logger.RuleFactory=WARN
log4j.logger.PerceptronTraining=WARN
log4j.logger.PragmaticScorer=WARN
log4j.logger.NorvigApproach=WARN
log4j.logger.CRF=WARN
12 changes: 6 additions & 6 deletions src/main/scala/com/johnsnowlabs/nlp/AnnotatorApproach.scala
@@ -1,7 +1,7 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.{Estimator, Model, PipelineModel}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.types.{ArrayType, MetadataBuilder, StructField, StructType}
import org.apache.spark.ml.util.DefaultParamsWritable
@@ -15,14 +15,14 @@ import org.apache.spark.ml.util.DefaultParamsWritable
*/
abstract class AnnotatorApproach[M <: Model[M]]
extends Estimator[M]
with HasInputAnnotationCols
with HasOutputAnnotationCol
with HasAnnotatorType
with DefaultParamsWritable {
with HasInputAnnotationCols
with HasOutputAnnotationCol
with HasAnnotatorType
with DefaultParamsWritable {

val description: String

def train(dataset: Dataset[_]): M
def train(dataset: Dataset[_], recursivePipeline: Option[PipelineModel] = None): M

def beforeTraining(spark: SparkSession): Unit = {}
