Skip to content

Commit

Permalink
Fix M2M100 not working on the second run [skip test]
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi committed Mar 2, 2024
1 parent ad5a4ea commit 75d398e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
14 changes: 6 additions & 8 deletions src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ private[johnsnowlabs] class M2M100(
applySoftmax = false)

// Run the prompt through the decoder and get the past
// val decoderOutputs =
// generateGreedyOnnx(
// decoderInputIds,
// decoderEncoderStateTensors,
// encoderAttentionMaskTensors,
// onnxSession = (decoderSession, decoderEnv))
// val decoderOutputs =
// generateGreedyOnnx(
// decoderInputIds,
// decoderEncoderStateTensors,
// encoderAttentionMaskTensors,
// onnxSession = (decoderSession, decoderEnv))

// close sessions
decoderEncoderStateTensors.fold(
Expand All @@ -216,8 +216,6 @@ private[johnsnowlabs] class M2M100(
},
onnxTensor => onnxTensor.close())

encoderSession.close()
decoderSession.close()
encoderEnv.close()
decoderEnv.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class LLAMA2TestSpec extends AnyFlatSpec {

val pipelineModel = pipeline.fit(testData)

pipelineModel
.transform(testData)
.show(truncate = false)

pipelineModel
.transform(testData)
.show(truncate = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@ class M2M100TestSpec extends AnyFlatSpec {
.setOutputCol("generation")
.setBeamSize(1)

new Pipeline()
val pipeline = new Pipeline()
.setStages(Array(documentAssembler, m2m100))
.fit(testData)
.transform(testData)
.show(truncate = false)

val pipelineModel = pipeline.fit(testData)

val result = pipelineModel.transform(testData)

result.show(truncate = false)
result.show(truncate = false)

}

Expand Down

0 comments on commit 75d398e

Please sign in to comment.