From 3d52a7c9a6737b858c07d457c33d889109af50c2 Mon Sep 17 00:00:00 2001
From: Prabod Rathnayaka
Date: Tue, 26 Mar 2024 11:52:56 +0000
Subject: [PATCH 1/2] added Nomic Scala api and tests

---
 .../scala/com/johnsnowlabs/ml/ai/Nomic.scala  | 151 +++++++
 .../nlp/embeddings/NomicEmbeddings.scala      | 422 ++++++++++++++++++
 .../embeddings/NomicEmbeddingsTestSpec.scala  | 160 +++++++
 3 files changed, 733 insertions(+)
 create mode 100644 src/main/scala/com/johnsnowlabs/ml/ai/Nomic.scala
 create mode 100644 src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala
 create mode 100644 src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala

diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Nomic.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Nomic.scala
new file mode 100644
index 000000000000..8f3df0cdc6b5
--- /dev/null
+++ b/src/main/scala/com/johnsnowlabs/ml/ai/Nomic.scala
@@ -0,0 +1,151 @@
+/*
+ * Copyright 2017 - 2023 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.ml.ai
+
+import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
+import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig
+import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
+import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
+
+import com.johnsnowlabs.nlp.annotators.common._
+import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.collection.JavaConverters._
+
+private[johnsnowlabs] class Nomic(
+    val onnxWrapper: Option[OnnxWrapper],
+    sentenceStartTokenId: Int,
+    sentenceEndTokenId: Int)
+    extends Serializable {
+
+  protected val logger: Logger = LoggerFactory.getLogger("NOMIC_EMBEDDINGS")
+
+  private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions
+
+  /** Get sentence embeddings for a batch of sentences
+    * @param batch
+    *   batch of sentences
+    * @return
+    *   sentence embeddings
+    */
+  private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = {
+    val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
+    val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength))
+    val embeddings =
+      getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength)
+    embeddings
+  }
+
+  private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = {
+    if (arr.length >= maxLength) {
+      arr
+    } else {
+      arr ++ Array.fill(maxLength - arr.length)(0)
+    }
+  }
+
+  private def getSentenceEmbeddingFromOnnx(
+      batch: Seq[Array[Int]],
+      maxSentenceLength: Int): Array[Array[Float]] = {
+
+    val inputIds = batch.map(x => x.map(x => x.toLong)).toArray
+    val attentionMask = batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray
+
+    val (session: OrtSession, env: OrtEnvironment) =
+      onnxWrapper.get.getSession(onnxSessionOptions)
+
+    val tokenTensors = OnnxTensor.createTensor(env, inputIds)
+    val maskTensors = OnnxTensor.createTensor(env, attentionMask)
+    val inputs: java.util.Map[String, OnnxTensor] =
+      Map(
+        OnnxSignatures.encoderInputIDs -> tokenTensors,
+        OnnxSignatures.encoderAttentionMask -> maskTensors).asJava
+    val encoderResults = session.run(inputs)
+
+    val encoderStateBuffer =
+      try {
+        val encoderStateTensor = encoderResults
+          .get(OnnxSignatures.encoderOutput)
+          .get()
+          .asInstanceOf[OnnxTensor]
+
+        val shape = encoderStateTensor.getInfo.getShape
+        encoderStateTensor.getFloatBuffer
+          .array()
+          .grouped(shape(1).toInt)
+          .toArray
+      } finally {
+        if (encoderResults != null) encoderResults.close()
+      }
+
+    tokenTensors.close()
+    maskTensors.close()
+
+    encoderStateBuffer
+  }
+
+  /** Predict sentence embeddings for a batch of sentences
+    * @param sentences
+    *   sentences
+    * @param tokenizedSentences
+    *   tokenized sentences
+    * @param batchSize
+    *   batch size
+    * @param maxSentenceLength
+    *   max sentence length
+    * @return
+    */
+  def predict(
+      sentences: Seq[Annotation],
+      tokenizedSentences: Seq[WordpieceTokenizedSentence],
+      batchSize: Int,
+      maxSentenceLength: Int): Seq[Annotation] = {
+
+    tokenizedSentences
+      .zip(sentences)
+      .zipWithIndex
+      .grouped(batchSize)
+      .toArray
+      .flatMap { batch =>
+        val tokensBatch = batch.map(x => x._1._1.tokens)
+        val tokens = tokensBatch.map(x =>
+          Array(sentenceStartTokenId) ++ x
+            .map(y => y.pieceId)
+            .take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId))
+
+        val sentenceEmbeddings = getSentenceEmbedding(tokens)
+
+        batch.zip(sentenceEmbeddings).map { case (sentence, vectors) =>
+          Annotation(
+            annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS,
+            begin = sentence._1._2.begin,
+            end = sentence._1._2.end,
+            result = sentence._1._2.result,
+            metadata = sentence._1._2.metadata,
+            embeddings = vectors)
+        }
+      }
+  }
+
+  private object OnnxSignatures {
+    val encoderInputIDs: String = "input_ids"
+    val encoderAttentionMask: String = "attention_mask"
+
+    val encoderOutput: String = "sentence_embedding"
+  }
+}
diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala
new file mode 100644
index 000000000000..1a34bdef4915
--- /dev/null
+++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala
@@ -0,0 +1,422 @@
+/*
+ * Copyright 2017-2022 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.nlp.embeddings
+
+import com.johnsnowlabs.ml.ai.Nomic
+import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
+import com.johnsnowlabs.ml.tensorflow._
+import com.johnsnowlabs.ml.util.LoadExternalModel.{
+  loadTextAsset,
+  modelSanityCheck,
+  notSupportedEngineError
+}
+import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
+import com.johnsnowlabs.nlp._
+import com.johnsnowlabs.nlp.annotators.common._
+import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
+import com.johnsnowlabs.nlp.serialization.MapFeature
+import com.johnsnowlabs.storage.HasStorageRef
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.slf4j.{Logger, LoggerFactory}
+
+/** Sentence embeddings using NomicEmbeddings.
+  *
+  * NomicEmbeddings, an instruction-finetuned text embedding model that can generate text
+  * embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation,
+  * etc.)
+  *
+  * Pretrained models can be loaded with `pretrained` of the companion object:
+  * {{{
+  * val embeddings = NomicEmbeddings.pretrained()
+  *   .setInputCols("document")
+  *   .setOutputCol("nomic_embeddings")
+  * }}}
+  * The default model is `"nomic_small"`, if no name is provided.
+  *
+  * For available pretrained models please see the
+  * [[https://sparknlp.org/models?q=NomicEmbeddings Models Hub]].
+  *
+  * For extended examples of usage, see
+  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala NomicEmbeddingsTestSpec]].
+  *
+  * '''Sources''' :
+  *
+  * [[https://arxiv.org/pdf/2212.03533 Text Embeddings by Weakly-Supervised Contrastive Pre-training]]
+  *
+  * [[https://github.com/microsoft/unilm/tree/master/nomic NomicEmbeddings Github Repository]]
+  *
+  * ''' Paper abstract '''
+  *
+  * ''This paper presents NomicEmbeddings, a family of state-of-the-art text embeddings that
+  * transfer well to a wide range of tasks. The model is trained in a contrastive manner with weak
+  * supervision signals from our curated large-scale text pair dataset (called CCPairs).
+  * NomicEmbeddings can be readily used as a general-purpose embedding model for any tasks
+  * requiring a single-vector representation of texts such as retrieval, clustering, and
+  * classification, achieving strong performance in both zero-shot and fine-tuned settings. We
+  * conduct extensive evaluations on 56 datasets from the BEIR and MTEB benchmarks. For zero-shot
+  * settings, NomicEmbeddings is the first model that outperforms the strong BM25 baseline on the
+  * BEIR retrieval benchmark without using any labeled data. When fine-tuned, NomicEmbeddings
+  * obtains the best results on the MTEB benchmark, beating existing embedding models with 40×
+  * more parameters.''
+  *
+  * ==Example==
+  * {{{
+  * import spark.implicits._
+  * import com.johnsnowlabs.nlp.base.DocumentAssembler
+  * import com.johnsnowlabs.nlp.annotators.Tokenizer
+  * import com.johnsnowlabs.nlp.embeddings.NomicEmbeddings
+  * import com.johnsnowlabs.nlp.EmbeddingsFinisher
+  * import org.apache.spark.ml.Pipeline
+  *
+  * val documentAssembler = new DocumentAssembler()
+  *   .setInputCol("text")
+  *   .setOutputCol("document")
+  *
+  * val embeddings = NomicEmbeddings.pretrained("nomic_small", "en")
+  *   .setInputCols("document")
+  *   .setOutputCol("nomic_embeddings")
+  *
+  * val embeddingsFinisher = new EmbeddingsFinisher()
+  *   .setInputCols("nomic_embeddings")
+  *   .setOutputCols("finished_embeddings")
+  *   .setOutputAsVector(true)
+  *
+  * val pipeline = new Pipeline().setStages(Array(
+  *   documentAssembler,
+  *   embeddings,
+  *   embeddingsFinisher
+  * ))
+  *
+  * val data = Seq(
+  *   "query: how much protein should a female eat",
+  *   "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. " +
+  *     "But, as you can see from this chart, you'll need to increase that if you're expecting or training for a " +
+  *     "marathon. Check out the chart below to see how much protein you should be eating each day."
+  * ).toDF("text")
+  * val result = pipeline.fit(data).transform(data)
+  *
+  * result.selectExpr("explode(finished_embeddings) as result").show(1, 80)
+  * +--------------------------------------------------------------------------------+
+  * |                                                                          result|
+  * +--------------------------------------------------------------------------------+
+  * |[[8.0190285E-4, -0.005974853, -0.072875895, 0.007944068, 0.026059335, -0.0080...|
+  * |[[0.050514214, 0.010061974, -0.04340176, -0.020937217, 0.05170225, 0.01157857...|
+  * +--------------------------------------------------------------------------------+
+  * }}}
+  *
+  * @see
+  *   [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer
+  *   based embeddings
+  * @param uid
+  *   required uid for storing annotator to disk
+  * @groupname anno Annotator types
+  * @groupdesc anno
+  *   Required input and expected output annotator types
+  * @groupname Ungrouped Members
+  * @groupname param Parameters
+  * @groupname setParam Parameter setters
+  * @groupname getParam Parameter getters
+  * @groupname Ungrouped Members
+  * @groupprio param 1
+  * @groupprio anno 2
+  * @groupprio Ungrouped 3
+  * @groupprio setParam 4
+  * @groupprio getParam 5
+  * @groupdesc param
+  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
+  *   parameter values through setters and getters, respectively.
+  */
+class NomicEmbeddings(override val uid: String)
+    extends AnnotatorModel[NomicEmbeddings]
+    with HasBatchedAnnotate[NomicEmbeddings]
+    with WriteTensorflowModel
+    with WriteOnnxModel
+    with HasEmbeddingsProperties
+    with HasStorageRef
+    with HasCaseSensitiveProperties
+    with HasEngine {
+
+  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
+    * type
+    */
+  override val inputAnnotatorTypes: Array[String] =
+    Array(AnnotatorType.DOCUMENT)
+  override val outputAnnotatorType: AnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
+
+  /** ConfigProto from tensorflow, serialized into byte array. Get with
+    * `config_proto.SerializeToString()`
+    *
+    * @group param
+    */
+  val configProtoBytes = new IntArrayParam(
+    this,
+    "configProtoBytes",
+    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")
+
+  /** Max sentence length to process (Default: `128`)
+    *
+    * @group param
+    */
+  val maxSentenceLength =
+    new IntParam(this, "maxSentenceLength", "Max sentence length to process")
+
+  def sentenceStartTokenId: Int = {
+    $$(vocabulary)("[CLS]")
+  }
+
+  /** @group setParam */
+  def sentenceEndTokenId: Int = {
+    $$(vocabulary)("[SEP]")
+  }
+
+  /** Vocabulary used to encode the words to ids with WordPieceEncoder
+    *
+    * @group param
+    */
+  val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected()
+
+  /** @group setParam */
+  def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value)
+
+  private var _model: Option[Broadcast[Nomic]] = None
+
+  def this() = this(Identifiable.randomUID("NOMIC_EMBEDDINGS"))
+
+  /** @group setParam */
+  def setMaxSentenceLength(value: Int): this.type = {
+    require(
+      value <= 512,
+      "NomicEmbeddings models do not support sequences longer than 512 because of trainable positional embeddings.")
+    require(value >= 1, "The maxSentenceLength must be at least 1")
+    set(maxSentenceLength, value)
+    this
+  }
+
+  /** @group getParam */
+  def getMaxSentenceLength: Int = $(maxSentenceLength)
+
+  /** @group setParam */
+  def setModelIfNotSet(spark: SparkSession, onnxWrapper: Option[OnnxWrapper]): NomicEmbeddings = {
+    if (_model.isEmpty) {
+      _model = Some(
+        spark.sparkContext.broadcast(
+          new Nomic(
+            onnxWrapper,
+            sentenceStartTokenId = sentenceStartTokenId,
+            sentenceEndTokenId = sentenceEndTokenId)))
+    }
+
+    this
+  }
+
+  /** Set the embeddings dimension for the model. It can only be set before the model is saved
+    * the first time; afterwards the dimension is not changeable, as it comes from the model's config file.
+    *
+    * @group setParam
+    */
+  override def setDimension(value: Int): this.type = {
+    if (get(dimension).isEmpty)
+      set(this.dimension, value)
+    this
+  }
+
+  /** Whether to lowercase tokens or not
+    *
+    * @group setParam
+    */
+  override def setCaseSensitive(value: Boolean): this.type = {
+    if (get(caseSensitive).isEmpty)
+      set(this.caseSensitive, value)
+    this
+  }
+
+  setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> false)
+
+  def tokenize(sentences: Seq[Annotation]): Seq[WordpieceTokenizedSentence] = {
+    val basicTokenizer = new BasicTokenizer($(caseSensitive))
+    val encoder = new WordpieceEncoder($$(vocabulary))
+    sentences.map { s =>
+      val sent = Sentence(
+        content = s.result,
+        start = s.begin,
+        end = s.end,
+        metadata = Some(s.metadata),
+        index = s.begin)
+      val tokens = basicTokenizer.tokenize(sent)
+      val wordpieceTokens = tokens.flatMap(token => encoder.encode(token))
+      WordpieceTokenizedSentence(wordpieceTokens)
+    }
+  }
+
+  /** takes a document and annotations and produces new annotations of this annotator's annotation
+    * type
+    *
+    * @param batchedAnnotations
+    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
+    * @return
+    *   any number of annotations processed for every input annotation. Not necessarily a one-to-one
+    *   relationship.
+    */
+  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
+
+    val allAnnotations = batchedAnnotations
+      .filter(_.nonEmpty)
+      .zipWithIndex
+      .flatMap { case (annotations, i) =>
+        annotations.filter(_.result.nonEmpty).map(x => (x, i))
+      }
+
+    // Tokenize sentences
+    val tokenizedSentences = tokenize(allAnnotations.map(_._1))
+    val processedAnnotations = if (allAnnotations.nonEmpty) {
+      this.getModelIfNotSet.predict(
+        sentences = allAnnotations.map(_._1),
+        tokenizedSentences = tokenizedSentences,
+        batchSize = $(batchSize),
+        maxSentenceLength = $(maxSentenceLength))
+    } else {
+      Seq()
+    }
+
+    // Group the resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowAnnotations = processedAnnotations
+        // zip each annotation with its corresponding row index
+        .zip(allAnnotations)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // leave the annotation only
+        .map(_._1)
+
+      if (rowAnnotations.nonEmpty)
+        rowAnnotations
+      else
+        Seq.empty[Annotation]
+    })
+
+  }
+
+  /** @group getParam */
+  def getModelIfNotSet: Nomic = _model.get.value
+
+  override def onWrite(path: String, spark: SparkSession): Unit = {
+    super.onWrite(path, spark)
+    val suffix = "_nomic"
+
+    getEngine match {
+      case ONNX.name =>
+        writeOnnxModel(
+          path,
+          spark,
+          getModelIfNotSet.onnxWrapper.get,
+          suffix,
+          NomicEmbeddings.onnxFile)
+
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+  }
+
+  override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
+    dataset.withColumn(
+      getOutputCol,
+      wrapSentenceEmbeddingsMetadata(
+        dataset.col(getOutputCol),
+        $(dimension),
+        Some($(storageRef))))
+  }
+
+}
+
+trait ReadablePretrainedNomicEmbeddingsModel
+    extends ParamsAndFeaturesReadable[NomicEmbeddings]
+    with HasPretrained[NomicEmbeddings] {
+  override val defaultModelName: Some[String] = Some("nomic_small")
+
+  /** Java compliant-overrides */
+  override def pretrained(): NomicEmbeddings = super.pretrained()
+
+  override def pretrained(name: String): NomicEmbeddings = super.pretrained(name)
+
+  override def pretrained(name: String, lang: String): NomicEmbeddings =
+    super.pretrained(name, lang)
+
+  override def pretrained(name: String, lang: String, remoteLoc: String): NomicEmbeddings =
+    super.pretrained(name, lang, remoteLoc)
+}
+
+trait ReadNomicEmbeddingsDLModel extends ReadTensorflowModel with ReadOnnxModel {
+  this: ParamsAndFeaturesReadable[NomicEmbeddings] =>
+
+  override val tfFile: String = "nomic_tensorflow"
+  override val onnxFile: String = "nomic_onnx"
+
+  def readModel(instance: NomicEmbeddings, path: String, spark: SparkSession): Unit = {
+
+    instance.getEngine match {
+      case ONNX.name =>
+        val onnxWrapper =
+          readOnnxModel(path, spark, "_nomic_onnx", zipped = true, useBundle = false, None)
+        instance.setModelIfNotSet(spark, Some(onnxWrapper))
+
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+
+  }
+
+  addReader(readModel)
+
+  def loadSavedModel(modelPath: String, spark: SparkSession): NomicEmbeddings = {
+
+    val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)
+
+    val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap
+
+    /*Universal parameters for all engines*/
+    val annotatorModel = new NomicEmbeddings()
+      .setVocabulary(vocabs)
+
+    annotatorModel.set(annotatorModel.engine, detectedEngine)
+
+    detectedEngine match {
+
+      case ONNX.name =>
+        val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true)
+        annotatorModel
+          .setModelIfNotSet(spark, Some(onnxWrapper))
+
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+
+    annotatorModel
+  }
+}
+
+/** This is the companion object of [[NomicEmbeddings]]. Please refer to that class for the
+  * documentation.
+  */
+object NomicEmbeddings
+    extends ReadablePretrainedNomicEmbeddingsModel
+    with ReadNomicEmbeddingsDLModel {
+  private[NomicEmbeddings] val logger: Logger =
+    LoggerFactory.getLogger("NomicEmbeddings")
+}
diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala
new file mode 100644
index 000000000000..1e10b7f883df
--- /dev/null
+++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala
@@ -0,0 +1,160 @@
+/*
+ * Copyright 2017-2022 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.nlp.embeddings
+
+import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
+import com.johnsnowlabs.nlp.base.DocumentAssembler
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import com.johnsnowlabs.tags.{SlowTest, SlowTest}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.sql.functions.{col, size}
+import org.scalatest.flatspec.AnyFlatSpec
+
+class NomicEmbeddingsTestSpec extends AnyFlatSpec {
+
+  "Nomic Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in {
+
+    import ResourceHelper.spark.implicits._
+
+    val ddd = Seq(
+      "query: how much protein should a female eat",
+      "query: summit define",
+      "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 " +
+        "grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or" +
+        " training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
+      "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of" +
+        " a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more" +
+        " governments.")
+      .toDF("text")
+
+    val document = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val embeddings = NomicEmbeddings
+      .pretrained()
+      .setInputCols(Array("document"))
+      .setOutputCol("nomic")
+
+    val pipeline = new Pipeline().setStages(Array(document, embeddings))
+
+    val pipelineDF = pipeline.fit(ddd).transform(ddd)
+    pipelineDF.select("nomic.embeddings").show(truncate = false)
+
+  }
+
+  it should "have embeddings of the same size" taggedAs SlowTest in {
+    import ResourceHelper.spark.implicits._
+    val testDf = Seq(
+      "I like apples",
+      "I like bananas \\n and other things \\n like ice cream \\n and cats",
+      "I like rockets")
+      .toDF("text")
+
+    val document = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val embeddings = NomicEmbeddings
+      .pretrained()
+      .setInputCols(Array("document"))
+      .setOutputCol("nomic")
+
+    val pipeline = new Pipeline().setStages(Array(document, embeddings))
+
+    val pipelineDF = pipeline.fit(testDf).transform(testDf)
+
+    val embeddingsDF = pipelineDF.withColumn("embeddings", col("nomic.embeddings").getItem(0))
+
+    val sizesArray: Array[Int] = embeddingsDF
+      .select(size(col("embeddings")).as("size"))
+      .collect()
+      .map(row => row.getAs[Int]("size"))
+
+    assert(sizesArray.forall(_ == sizesArray.head))
+  }
+
+  it should "work with sentences" taggedAs SlowTest in {
+    import ResourceHelper.spark.implicits._
+    val testData = "I really enjoy my job. This is amazing"
+    val testDf = Seq(testData).toDF("text")
+
+    val document = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val sentenceDetectorDL = SentenceDetectorDLModel
+      .pretrained("sentence_detector_dl", "en")
+      .setInputCols(Array("document"))
+      .setOutputCol("sentences")
+
+    val embeddings = NomicEmbeddings
+      .pretrained()
+      .setInputCols(Array("sentences"))
+      .setOutputCol("nomic")
+
+    val pipeline = new Pipeline().setStages(Array(document, sentenceDetectorDL, embeddings))
+
+    val pipelineDF = pipeline.fit(testDf).transform(testDf)
+    pipelineDF.select("nomic.embeddings").show(false)
+  }
+
+  it should "not return empty embeddings" taggedAs SlowTest in {
+    import ResourceHelper.spark.implicits._
+    val interests = Seq(
+      "I like music",
+      "I like movies",
+      "I like books",
+      "I like sports",
+      "I like travel",
+      "I like food",
+      "I like games",
+      "I like art",
+      "I like nature",
+      "I like science",
+      "I like technology",
+      "I like history",
+      "I like fashion",
+      "I like cars",
+      "I like animals",
+      "I like gardening")
+    val testDf = interests.toDF("text")
+
+    val document = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val embeddings = NomicEmbeddings
+      .pretrained()
+      .setInputCols(Array("document"))
+      .setOutputCol("nomic")
+
+    val pipeline = new Pipeline().setStages(Array(document, embeddings))
+
+    val pipelineDF = pipeline.fit(testDf).transform(testDf)
+
+    val embeddingsDF = pipelineDF.withColumn("embeddings", col("nomic.embeddings").getItem(0))
+
+    val sizesArray: Array[Int] = embeddingsDF
+      .select(size(col("embeddings")).as("size"))
+      .collect()
+      .map(row => row.getAs[Int]("size"))
+
+    assert(sizesArray.forall(_ > 0))
+  }
+
+}

From fe5537a8ce9b46d5e81146d9bd860481f8bdb8a8 Mon Sep 17 00:00:00 2001
From: Prabod Rathnayaka
Date: Wed, 27 Mar 2024 12:06:02 +0000
Subject: [PATCH 2/2] added Nomic python api and tests

---
 .../sparknlp/annotator/embeddings/__init__.py |   1 +
 .../annotator/embeddings/nomic_embeddings.py  | 181 ++++++++++++++++++
 python/sparknlp/internal/__init__.py          |   3 +
 .../embeddings/nomic_embeddings_test.py       |  56 ++++++
 .../nlp/embeddings/NomicEmbeddings.scala      |  31 ++-
 .../embeddings/NomicEmbeddingsTestSpec.scala  |   2 +-
 6 files changed, 255 insertions(+), 19 deletions(-)
 create mode 100644 python/sparknlp/annotator/embeddings/nomic_embeddings.py
 create mode 100644 python/test/annotator/embeddings/nomic_embeddings_test.py

diff --git a/python/sparknlp/annotator/embeddings/__init__.py b/python/sparknlp/annotator/embeddings/__init__.py
index 1ddf7952558d..c3ed901bae48 100644
--- a/python/sparknlp/annotator/embeddings/__init__.py
+++ b/python/sparknlp/annotator/embeddings/__init__.py
@@ -36,3 +36,4 @@
 from sparknlp.annotator.embeddings.xlm_roberta_sentence_embeddings import *
 from sparknlp.annotator.embeddings.xlnet_embeddings import *
 from sparknlp.annotator.embeddings.bge_embeddings import *
+from sparknlp.annotator.embeddings.nomic_embeddings import *
diff --git a/python/sparknlp/annotator/embeddings/nomic_embeddings.py b/python/sparknlp/annotator/embeddings/nomic_embeddings.py
new file mode 100644
index 000000000000..8cd303e4c81d
--- /dev/null
+++ b/python/sparknlp/annotator/embeddings/nomic_embeddings.py
@@ -0,0 +1,181 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains classes for NomicEmbeddings."""
+
+from sparknlp.common import *
+
+
+class NomicEmbeddings(AnnotatorModel, HasEmbeddingsProperties, HasCaseSensitiveProperties, HasStorageRef,
+                      HasBatchedAnnotate, HasMaxSentenceLengthLimit):
+    """Sentence embeddings using NomicEmbeddings.
+
+    nomic-embed-text-v1 is an 8192 context length text encoder that surpasses the performance
+    of OpenAI text-embedding-ada-002 and text-embedding-3-small on short and long context tasks.
+
+    Pretrained models can be loaded with :meth:`.pretrained` of the companion
+    object:
+
+    >>> embeddings = NomicEmbeddings.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("nomic_embeddings")
+
+
+    The default model is ``"nomic_small"``, if no name is provided.
+
+    For available pretrained models please see the
+    `Models Hub <https://sparknlp.org/models?q=NomicEmbeddings>`__.
+
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT``           ``SENTENCE_EMBEDDINGS``
+    ====================== ======================
+
+    Parameters
+    ----------
+    batchSize
+        Size of every batch, by default 8
+    dimension
+        Number of embedding dimensions, by default 768
+    caseSensitive
+        Whether to ignore case in tokens for embeddings matching, by default False
+    maxSentenceLength
+        Max sentence length to process, by default 512
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+
+    References
+    ----------
+    `Nomic Embed: Training a Reproducible Long Context Text Embedder <https://static.nomic.ai/reports/2024_Nomic_Embed_Text_Technical_Report.pdf>`__
+
+    https://github.com/nomic-ai/contrastors
+
+    **Paper abstract**
+
+    *This technical report describes the training
+    of nomic-embed-text-v1, the first fully reproducible,
+    open-source, open-weights, open-data, 8192 context length
+    English text embedding model that outperforms both OpenAI
+    Ada-002 and OpenAI text-embedding-3-small
+    on short and long-context tasks. We release
+    the training code and model weights under
+    an Apache 2 license. In contrast with other
+    open-source models, we release a training data
+    loader with 235 million curated text pairs that
+    allows for the full replication of nomic-embed-text-v1.
+    You can find code and data to replicate the
+    model at https://github.com/nomic-ai/contrastors.*
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document")
+    >>> embeddings = NomicEmbeddings.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("nomic_embeddings")
+    >>> embeddingsFinisher = EmbeddingsFinisher() \\
+    ...     .setInputCols(["nomic_embeddings"]) \\
+    ...     .setOutputCols("finished_embeddings") \\
+    ...     .setOutputAsVector(True)
+    >>> pipeline = Pipeline().setStages([
+    ...     documentAssembler,
+    ...     embeddings,
+    ...     embeddingsFinisher
+    ... ])
+    >>> data = spark.createDataFrame([["query: how much protein should a female eat"],
+    ...     ["passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. " + \
+    ...     "But, as you can see from this chart, you'll need to increase that if you're expecting or training for a " + \
+    ...     "marathon. Check out the chart below to see how much protein you should be eating each day."],
+    ... ]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
+    +--------------------------------------------------------------------------------+
+    |                                                                          result|
+    +--------------------------------------------------------------------------------+
+    |[[8.0190285E-4, -0.005974853, -0.072875895, 0.007944068, 0.026059335, -0.0080...|
+    |[[0.050514214, 0.010061974, -0.04340176, -0.020937217, 0.05170225, 0.01157857...|
+    +--------------------------------------------------------------------------------+
+    """
+
+    name = "NomicEmbeddings"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
+    configProtoBytes = Param(Params._dummy(), "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.NomicEmbeddings", java_model=None):
+        super(NomicEmbeddings, self).__init__(classname=classname, java_model=java_model)
+        self._setDefault(dimension=768, batchSize=8, maxSentenceLength=512, caseSensitive=False)
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+
+        Returns
+        -------
+        NomicEmbeddings
+            The restored model
+        """
+        from sparknlp.internal import _NomicLoader
+        jModel = _NomicLoader(folder, spark_session._jsparkSession)._java_obj
+        return NomicEmbeddings(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="nomic_small", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default "nomic_small"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLP's repositories otherwise.
+
+        Returns
+        -------
+        NomicEmbeddings
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(NomicEmbeddings, name, lang, remote_loc)
diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py
index c1aabeeb36ae..8c89ac8936e9 100644
--- a/python/sparknlp/internal/__init__.py
+++ b/python/sparknlp/internal/__init__.py
@@ -251,6 +251,9 @@ def __init__(self, path, jspark, useCache):
         super(_BartLoader, self).__init__(
             "com.johnsnowlabs.nlp.annotators.seq2seq.BartTransformer.loadSavedModel", path, jspark, useCache)
 
+class _NomicLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark):
+        super(_NomicLoader, self).__init__("com.johnsnowlabs.nlp.embeddings.NomicEmbeddings.loadSavedModel", path, jspark)
 
 class _USELoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark, loadsp):
diff --git a/python/test/annotator/embeddings/nomic_embeddings_test.py b/python/test/annotator/embeddings/nomic_embeddings_test.py
new file mode 100644
index 000000000000..e1da0e63626b
--- /dev/null
+++ b/python/test/annotator/embeddings/nomic_embeddings_test.py
@@ -0,0 +1,56 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import unittest
+
+import pytest
+
+from sparknlp.annotator import *
+from sparknlp.base import *
+from test.annotator.common.has_max_sentence_length_test import HasMaxSentenceLengthTests
+from test.util import SparkContextForTest
+
+
+@pytest.mark.slow
+class NomicEmbeddingsTestSpec(unittest.TestCase):
+    def setUp(self):
+        self.spark = SparkContextForTest.spark
+        self.tested_annotator = NomicEmbeddings \
+            .pretrained() \
+            .setInputCols(["documents"]) \
+            .setOutputCol("nomic")
+
+    def runTest(self):
+        data = self.spark.createDataFrame([
+            [1, "query: how much protein should a female eat"],
+            [2, "query: summit define"],
+            [3, "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 "
+                "is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're "
+                "expecting or training for a marathon. Check out the chart below to see how much protein you should "
+                "be eating each day.", ],
+            [4, "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain :"
+                " the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the "
+                "leaders of two or more governments."]
+        ]).toDF("id", "text")
+
+        document_assembler = DocumentAssembler() \
+            .setInputCol("text") \
+            .setOutputCol("documents")
+
+        nomic = self.tested_annotator
+
+        pipeline = Pipeline().setStages([document_assembler, nomic])
+        results = pipeline.fit(data).transform(data)
+
+        results.select("nomic.embeddings").show(truncate=False)
diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala
index 1a34bdef4915..dc680be3406d 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddings.scala
@@ -38,9 +38,8 @@ import org.slf4j.{Logger, LoggerFactory}
 
 /** Sentence embeddings using NomicEmbeddings.
   *
-  * NomicEmbeddings, an instruction-finetuned text embedding model that can generate text
-  * embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation,
-  * etc.)
+  * nomic-embed-text-v1 is an 8192 context length text encoder that surpasses the performance
+  * of OpenAI text-embedding-ada-002 and text-embedding-3-small on short and long context tasks.
   *
   * Pretrained models can be loaded with `pretrained` of the companion object:
   * {{{
@@ -58,23 +57,19 @@
   *
   * '''Sources''' :
   *
-  * [[https://arxiv.org/pdf/2212.03533 Text Embeddings by Weakly-Supervised Contrastive Pre-training]]
+  * [[https://static.nomic.ai/reports/2024_Nomic_Embed_Text_Technical_Report.pdf Nomic Embed: Training a Reproducible Long Context Text Embedder]]
   *
-  * [[https://github.com/microsoft/unilm/tree/master/nomic NomicEmbeddings Github Repository]]
+  * [[https://github.com/nomic-ai/contrastors NomicEmbeddings Github Repository]]
   *
   * ''' Paper abstract '''
   *
-  * ''This paper presents NomicEmbeddings, a family of state-of-the-art text embeddings that
-  * transfer well to a wide range of tasks. The model is trained in a contrastive manner with weak
-  * supervision signals from our curated large-scale text pair dataset (called CCPairs).
-  * NomicEmbeddings can be readily used as a general-purpose embedding model for any tasks
-  * requiring a single-vector representation of texts such as retrieval, clustering, and
-  * classification, achieving strong performance in both zero-shot and fine-tuned settings. We
-  * conduct extensive evaluations on 56 datasets from the BEIR and MTEB benchmarks. For zero-shot
-  * settings, NomicEmbeddings is the first model that outperforms the strong BM25 baseline on the
-  * BEIR retrieval benchmark without using any labeled data. When fine-tuned, NomicEmbeddings
-  * obtains the best results on the MTEB benchmark, beating existing embedding models with 40×
-  * more parameters.''
+  * ''This technical report describes the training of nomic-embed-text-v1, the first fully
+  * reproducible, open-source, open-weights, open-data, 8192 context length English text embedding
+  * model that outperforms both OpenAI Ada-002 and OpenAI text-embedding-3-small on short and
+  * long-context tasks. We release the training code and model weights under an Apache 2 license.
+  * In contrast with other open-source models, we release a training data loader with 235 million
+  * curated text pairs that allows for the full replication of nomic-embed-text-v1. You can find
+  * code and data to replicate the model at https://github.com/nomic-ai/contrastors.''
   *
   * ==Example==
   * {{{
@@ -202,8 +197,8 @@
   /** @group setParam */
   def setMaxSentenceLength(value: Int): this.type = {
     require(
-      value <= 512,
-      "NomicEmbeddings models do not support sequences longer than 512 because of trainable positional embeddings.")
+      value <= 8192,
+      "NomicEmbeddings models do not support sequences longer than 8192, the model's maximum context length.")
     require(value >= 1, "The maxSentenceLength must be at least 1")
     set(maxSentenceLength, value)
     this
diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala
index 1e10b7f883df..65ccc0c6e8ab 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/NomicEmbeddingsTestSpec.scala
@@ -19,7 +19,7 @@ package com.johnsnowlabs.nlp.embeddings
 import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
 import com.johnsnowlabs.nlp.base.DocumentAssembler
 import com.johnsnowlabs.nlp.util.io.ResourceHelper
-import com.johnsnowlabs.tags.{SlowTest, SlowTest}
+import com.johnsnowlabs.tags.SlowTest
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.sql.functions.{col, size}
 import org.scalatest.flatspec.AnyFlatSpec
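
The examples in this series only exercise `pretrained()`. Below is a minimal usage sketch for the other entry point the patches add, `NomicEmbeddings.loadSavedModel`, which imports a locally exported ONNX model. It is a sketch rather than part of the patch: the directory paths are placeholders, and it assumes an external ONNX export has already produced the `vocab.txt` and ONNX weights that `modelSanityCheck` looks for.

    import com.johnsnowlabs.nlp.embeddings.NomicEmbeddings
    import org.apache.spark.sql.SparkSession

    object NomicOnnxImport extends App {
      val spark = SparkSession.builder().appName("nomic-onnx-import").getOrCreate()

      // Placeholder path: a directory produced by an external ONNX export,
      // containing vocab.txt and the ONNX weights.
      val exportedPath = "/tmp/nomic-embed-text-v1-onnx"

      val nomic = NomicEmbeddings
        .loadSavedModel(exportedPath, spark)
        .setInputCols("document")
        .setOutputCol("nomic_embeddings")
        .setDimension(768)

      // Persist as a Spark NLP annotator; it can be restored later with
      // NomicEmbeddings.load(...) and used in a Pipeline like a pretrained model.
      nomic.write.overwrite().save("/tmp/nomic_spark_nlp")
    }

Since `loadSavedModel` only wires up the ONNX engine (any other detected engine throws `notSupportedEngineError`), this path is ONNX-only until a TensorFlow reader is added.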