Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

M2M100Transformer Integration #254

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Empty file.
14 changes: 14 additions & 0 deletions nlu/components/seq2seqs/m2m100_transformer/m2m100_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from sparknlp.annotator import *

class M2M100:
    """Factory helpers that build Spark NLP ``M2M100Transformer`` annotators.

    Both helpers configure the returned annotator to read from the
    ``document`` column and to write its result to the ``generation`` column.
    """

    @staticmethod
    def get_default_model():
        """Return ``M2M100Transformer.pretrained()`` with columns configured."""
        annotator = M2M100Transformer.pretrained()
        return annotator.setInputCols("document").setOutputCol("generation")

    @staticmethod
    def get_pretrained_model(name, language, bucket=None):
        """Return a specific pretrained ``M2M100Transformer`` with columns configured.

        ``name``, ``language`` and ``bucket`` are forwarded positionally and
        unchanged to ``M2M100Transformer.pretrained``.
        """
        annotator = M2M100Transformer.pretrained(name, language, bucket)
        return annotator.setInputCols("document").setOutputCol("generation")
2 changes: 2 additions & 0 deletions nlu/spellbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -9196,6 +9196,7 @@ class Spellbook:
'xx.embed_sentence.bert_use_cmlm_multi_base_br': 'sent_bert_use_cmlm_multi_base_br',
'xx.embed_sentence.labse': 'labse',
'xx.embed_sentence.xlm_roberta.base': 'sent_xlm_roberta_base',
'xx.m2m100_418M': 'm2m100_418M',
'xx.en.marian.translate_to.aav': 'opus_mt_en_aav',
'xx.en.marian.translate_to.af': 'opus_mt_en_af',
'xx.en.marian.translate_to.afa': 'opus_mt_en_afa',
Expand Down Expand Up @@ -16876,6 +16877,7 @@ class Spellbook:
'onto_small_bert_L8_512': 'NerDLModel',
'openai.completion': 'OpenAICompletion',
'openai.embeddings': 'OpenAIEmbeddings',
'm2m100_418M': 'M2M100Transformer',
'opus_mt_aav_en': 'MarianTransformer',
'opus_mt_aed_es': 'MarianTransformer',
'opus_mt_af_de': 'MarianTransformer',
Expand Down
2 changes: 1 addition & 1 deletion nlu/universe/annotator_class_universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class AnnoClassRef:
A_N.DEBERTA_FOR_TOKEN_CLASSIFICATION: 'DeBertaForTokenClassification',
A_N.CAMEMBERT_EMBEDDINGS: 'CamemBertEmbeddings',
A_N.BART_TRANSFORMER: 'BartTransformer',

A_N.M2M100_TRANSFORMER: 'M2M100Transformer',
A_N.TRAINABLE_VIVEKN_SENTIMENT: 'ViveknSentimentApproach',
A_N.TRAINABLE_SENTIMENT: 'SentimentDetector',
A_N.TRAINABLE_SENTIMENT_DL: 'SentimentDLApproach',
Expand Down
21 changes: 21 additions & 0 deletions nlu/universe/component_universes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
from nlu.components.sentence_detectors.deep_sentence_detector.deep_sentence_detector import SentenceDetectorDeep
from nlu.components.sentence_detectors.pragmatic_sentence_detector.sentence_detector import PragmaticSentenceDetector
from nlu.components.seq2seqs.bart_transformer.bart_transformer import SparkNLPBartTransformer
from nlu.components.seq2seqs.m2m100_transformer.m2m100_transformer import M2M100
from nlu.components.seq2seqs.gpt2.gpt2 import GPT2
from nlu.components.seq2seqs.openai_completion.openai_completion import OpenaiCompletion
from nlu.components.embeddings.openai_embeddings.openai_embeddings import OpenaiEmbeddings
Expand Down Expand Up @@ -2814,6 +2815,26 @@ class ComponentUniverse:
jsl_anno_py_class=ACR.JSL_anno2_py_class[A.BART_TRANSFORMER],
),

A.M2M100_TRANSFORMER: partial(NluComponent,
name=A.M2M100_TRANSFORMER,
type=T.DOCUMENT_CLASSIFIER,
get_default_model=M2M100.get_default_model,
get_pretrained_model=M2M100.get_pretrained_model,
pdf_extractor_methods={'default': default_gpt2_config,
'default_full': default_full_config, },
pdf_col_name_substitutor=substitute_gpt2_cols,
output_level=L.INPUT_DEPENDENT_DOCUMENT_CLASSIFIER,
node=NLP_FEATURE_NODES.nodes[A.M2M100_TRANSFORMER],
description='Bart Transformer',
provider=ComponentBackends.open_source,
license=Licenses.open_source,
computation_context=ComputeContexts.spark,
output_context=ComputeContexts.spark,
jsl_anno_class_id=A.M2M100_TRANSFORMER,
jsl_anno_py_class=ACR.JSL_anno2_py_class[A.M2M100_TRANSFORMER],
),


H_A.MEDICAL_TEXT_GENERATOR: partial(NluComponent,
name=H_A.MEDICAL_TEXT_GENERATOR,
type=T.DOCUMENT_CLASSIFIER,
Expand Down
3 changes: 3 additions & 0 deletions nlu/universe/feature_node_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class NLP_NODE_IDS:
CONVNEXT_IMAGE_CLASSIFICATION = JslAnnoId("convnext_image_classification")
SWIN_IMAGE_CLASSIFICATION = JslAnnoId("swin_image_classification")
BART_TRANSFORMER = JslAnnoId("bart_transformer")

M2M100_TRANSFORMER = JslAnnoId("m2m100_transformer")

INSTRUCTOR_SENTENCE_EMBEDDINGS = JslAnnoId('instructor_sentence_embeddings')

MPNET_SENTENCE_EMBEDDINGS = JslAnnoId('mpnet_sentence_embeddings')
Expand Down
1 change: 1 addition & 0 deletions nlu/universe/feature_node_universes.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ class NLP_FEATURE_NODES: # or Mode Node?
A.CONVNEXT_IMAGE_CLASSIFICATION: NlpFeatureNode(A.CONVNEXT_IMAGE_CLASSIFICATION, [F.IMAGE], [F.CLASSIFIED_IMAGE]),
A.SWIN_IMAGE_CLASSIFICATION: NlpFeatureNode(A.SWIN_IMAGE_CLASSIFICATION, [F.IMAGE], [F.CLASSIFIED_IMAGE]),
A.BART_TRANSFORMER: NlpFeatureNode(A.BART_TRANSFORMER, [F.DOCUMENT], [F.DOCUMENT_GENERATED]),
A.M2M100_TRANSFORMER: NlpFeatureNode(A.M2M100_TRANSFORMER, [F.DOCUMENT], [F.DOCUMENT_GENERATED]),

}

Expand Down
14 changes: 14 additions & 0 deletions tests/nlu_core_tests/component_tests/seq2seq/m2m100_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os
import sys
import unittest
from nlu import *


class M2M100TransformerTests(unittest.TestCase):
    """Smoke test for the ``xx.m2m100_418M`` NLU model reference."""

    def test_m2m100_transformer(self):
        """Load the model, run it on one sentence and check output is produced."""
        model = nlu.load("xx.m2m100_418M")
        df = model.predict("生活就像一盒巧克力。")
        print(df.columns)
        # Printing alone can never fail; require that prediction actually
        # produced at least one output column.
        self.assertGreater(len(df.columns), 0)


if __name__ == "__main__":
    unittest.main()
Loading