Skip to content

Commit f5161ed

Browse files
committed
Add option for toggling CUDA, change library package
1 parent 28b8565 commit f5161ed

File tree

7 files changed

+12
-12
lines changed

7 files changed

+12
-12
lines changed

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ subprojects {
66
apply(plugin = "java")
77
apply(plugin = "maven-publish")
88

9-
group = "xyz.bluspring"
9+
group = "xyz.bluspring.unitytranslate"
1010
version = "${rootProject.property("unitytranslate_version")}"
1111

1212
repositories {

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
unitytranslate_version = 0.2.0-SNAPSHOT
1+
unitytranslate_version = 0.2.1
22

33
# This isn't actually used in the download process, however it's used for the sake of caching the downloaded files.
44
# https://github.com/OpenNMT/CTranslate2/blob/master/python/ctranslate2/version.py

library/src/main/kotlin/xyz/bluspring/unitytranslate/library/UnityTranslateLib.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class UnityTranslateLib(val path: Path) {
2727
packageIndex.load()
2828
}
2929

30-
private suspend fun createTranslator(code: String): Translator {
30+
private suspend fun createTranslator(code: String, useCuda: Boolean = false): Translator {
3131
val split = code.split("_")
3232

3333
val translator = if (split[0] == split[1])
@@ -36,7 +36,7 @@ class UnityTranslateLib(val path: Path) {
3636
else
3737
ModelBasedTranslator(this, code)
3838

39-
translator.load()
39+
translator.load(useCuda)
4040

4141
return translator
4242
}

library/src/main/kotlin/xyz/bluspring/unitytranslate/library/models/ModelPackageManager.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ class ModelPackageManager(val library: UnityTranslateLib) {
4242
return mapOf()
4343
}
4444

45-
suspend fun tryLoadModels(code: String): Map<String, Long> {
45+
suspend fun tryLoadModels(code: String, useCuda: Boolean): Map<String, Long> {
4646
val split = code.split("_")
4747
val fromCode = split[0]
4848
val toCode = split[1]
49-
return tryLoadModels(fromCode, toCode)
49+
return tryLoadModels(fromCode, toCode, useCuda)
5050
}
5151

52-
suspend fun tryLoadModels(fromCode: String, toCode: String): Map<String, Long> {
52+
suspend fun tryLoadModels(fromCode: String, toCode: String, useCuda: Boolean): Map<String, Long> {
5353
val modelInfos = this.getModelInfos(fromCode, toCode)
5454

5555
if (modelInfos.isEmpty())
@@ -65,7 +65,7 @@ class ModelPackageManager(val library: UnityTranslateLib) {
6565
}
6666

6767
infos.toList().asFlow().concurrent().collect { (pkg, modelInfo) ->
68-
val modelPtr = library.loadModel(modelInfo.modelPath.absolutePathString(), modelInfo.spModelPath?.absolutePathString(), modelInfo.bpeModelPath?.absolutePathString(), false)
68+
val modelPtr = library.loadModel(modelInfo.modelPath.absolutePathString(), modelInfo.spModelPath?.absolutePathString(), modelInfo.bpeModelPath?.absolutePathString(), useCuda)
6969

7070
if (modelPtr != 0L) {
7171
loadedModelPtrs[pkg.code] = modelPtr

library/src/main/kotlin/xyz/bluspring/unitytranslate/library/translator/DummyTranslator.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package xyz.bluspring.unitytranslate.library.translator
33
import xyz.bluspring.unitytranslate.library.UnityTranslateLib
44

55
class DummyTranslator(library: UnityTranslateLib, code: String) : Translator(library, code) {
6-
override suspend fun load() {
6+
override suspend fun load(useCuda: Boolean) {
77
}
88

99
override fun batchTranslate(texts: List<String>): List<String> {

library/src/main/kotlin/xyz/bluspring/unitytranslate/library/translator/ModelBasedTranslator.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ class ModelBasedTranslator(library: UnityTranslateLib, code: String) : Translato
77
var isReady = false
88
private set
99

10-
override suspend fun load() {
10+
override suspend fun load(useCuda: Boolean) {
1111
if (isReady)
1212
return
1313

14-
modelPtrs = library.packageIndex.tryLoadModels(code)
14+
modelPtrs = library.packageIndex.tryLoadModels(code, useCuda)
1515
isReady = true
1616
}
1717

library/src/main/kotlin/xyz/bluspring/unitytranslate/library/translator/Translator.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ package xyz.bluspring.unitytranslate.library.translator
33
import xyz.bluspring.unitytranslate.library.UnityTranslateLib
44

55
abstract class Translator(protected val library: UnityTranslateLib, protected val code: String) {
6-
abstract suspend fun load()
6+
abstract suspend fun load(useCuda: Boolean)
77
abstract fun batchTranslate(texts: List<String>): List<String>
88
}

0 commit comments

Comments
 (0)