Fix GitHub Actions
Fix GPU support of pooling for embedding texts
FBerding committed Mar 14, 2024
1 parent 61db320 commit 9d480c3
Showing 4 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
@@ -67,6 +67,7 @@ jobs:
conda_install(
packages = c(
"tensorflow-cpu",
+"tf-keras",
"torch",
"torcheval",
"safetensors",
6 changes: 5 additions & 1 deletion NEWS.md
@@ -38,6 +38,10 @@ editor_options:
- Adapted the interface according to the changes made in this version.
- Improved the read of raw texts. Reading now reduces multiple spaces characters to
one single space character. Hyphenation is removed.
+
+**Python Installation**
+
+- Updated installation to account for the new version of keras.


# aifeducation 0.3.1
@@ -99,7 +103,7 @@ editor_options:
- Added an argument to 'install_py_modules',
allowing to choose which machine learning framework should be
installed.
-- Updated 'check_aif_py_modules'.
+- Updated 'check_aif_py_modules'.

**Further Changes**

4 changes: 2 additions & 2 deletions R/install_and_config.R
@@ -69,15 +69,15 @@ install_py_modules<-function(envname="aifeducation",
reticulate::conda_install(
packages = c(
"tensorflow-cpu",
-"keras"),
+"tf-keras"),
envname = envname,
conda = "auto",
pip = TRUE)
} else {
reticulate::conda_install(
packages = c(
paste0("tensorflow",tf_version),
-"keras"),
+"tf-keras"),
envname = envname,
conda = "auto",
pip = TRUE)
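
Both branches of the hunk above now request `tf-keras` instead of `keras` alongside the TensorFlow package. A minimal base-R sketch of that package selection follows. The helper `tf_pip_packages`, its `cpu_only` flag, and the example version string are hypothetical and not part of aifeducation; the branch condition itself is not visible in the diff.

```r
# Hypothetical helper (not in aifeducation): builds the package vector that
# the diff passes to reticulate::conda_install().
# cpu_only   : assumed name for the condition selecting the first branch
# tf_version : version suffix appended to "tensorflow", as in the diff
tf_pip_packages <- function(cpu_only = TRUE, tf_version = "") {
  tensorflow_pkg <- if (cpu_only) {
    "tensorflow-cpu"
  } else {
    paste0("tensorflow", tf_version)
  }
  # After this commit both branches install "tf-keras" rather than "keras".
  c(tensorflow_pkg, "tf-keras")
}

cpu_pkgs <- tf_pip_packages(cpu_only = TRUE)
gpu_pkgs <- tf_pip_packages(cpu_only = FALSE, tf_version = "==2.15.0")  # illustrative version only

# The result could then be handed to reticulate, mirroring the arguments in the diff:
# reticulate::conda_install(packages = cpu_pkgs, envname = "aifeducation",
#                           conda = "auto", pip = TRUE)
```

Keeping the package list in one place would make it harder for the two branches to drift apart the next time a dependency is renamed.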
8 changes: 5 additions & 3 deletions R/text_embedding_model.R
@@ -1227,7 +1227,6 @@ TextEmbeddingModel<-R6::R6Class(
}

for (b in 1:n_batches){
-
if(private$transformer_components$ml_framework=="pytorch"){
#Set model to evaluation mode
private$transformer_components$model$eval()
@@ -1237,6 +1236,9 @@ TextEmbeddingModel<-R6::R6Class(
pytorch_device="cpu"
}
private$transformer_components$model$to(pytorch_device)
+if(private$transformer_components$emb_pool_type=="average"){
+pooling$to(pytorch_device)
+}
}

#tokens<-self$encode(raw_text = raw_text,
@@ -1304,8 +1306,8 @@ TextEmbeddingModel<-R6::R6Class(
if(private$transformer_components$emb_pool_type=="average"){
#Average Pooling over all tokens
for(i in tmp_selected_layer){
-tensor_embeddings[i]=list(pooling(x=tensor_embeddings[[as.integer(i)]],
-mask=tokens$encodings["attention_mask"]))
+tensor_embeddings[i]=list(pooling(x=tensor_embeddings[[as.integer(i)]]$to(pytorch_device),
+mask=tokens$encodings["attention_mask"]$to(pytorch_device)))
}
}
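
The device fix above matters because average pooling combines two tensors, the layer output and the attention mask, and PyTorch requires the operands of one operation to sit on the same device. As a rough illustration of what such a masked average computes, here is a base-R sketch. It assumes `pooling` is a masked mean over tokens, which the diff's comment suggests but does not show; `masked_mean_pool` and the toy data are hypothetical and not part of aifeducation.

```r
# Hypothetical base-R stand-in for the package's pooling layer.
# x    : numeric array with dimensions [batch, tokens, features]
# mask : 0/1 matrix with dimensions [batch, tokens]; 1 = real token, 0 = padding
masked_mean_pool <- function(x, mask) {
  stopifnot(length(dim(x)) == 3, all(dim(x)[1:2] == dim(mask)))
  n_batch <- dim(x)[1]
  n_feat <- dim(x)[3]
  pooled <- matrix(0, nrow = n_batch, ncol = n_feat)
  for (b in seq_len(n_batch)) {
    keep <- mask[b, ] == 1
    # Rebuild a tokens-by-features matrix so that a single kept token still works
    kept_tokens <- matrix(x[b, keep, ], ncol = n_feat)
    # Average over the kept tokens only; padding does not contribute
    pooled[b, ] <- colMeans(kept_tokens)
  }
  pooled
}

# Toy batch: 2 texts, 3 token positions, 2 features
x <- array(0, dim = c(2, 3, 2))
x[1, , ] <- rbind(c(1, 2), c(3, 4), c(100, 100))  # third position is padding
x[2, , ] <- rbind(c(2, 2), c(4, 6), c(6, 10))
mask <- rbind(c(1, 1, 0),
              c(1, 1, 1))

pooled <- masked_mean_pool(x, mask)
# Row 1 averages only the first two tokens: (2, 3)
# Row 2 averages all three tokens:          (4, 6)
```

Because every output row depends on both `x` and `mask`, a GPU version of this step needs the pooling layer, the embeddings, and the mask on the same device, which is what the three `$to(pytorch_device)` calls in the diff arrange.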

