Skip to content

Commit

Permalink
Reimplementing Keras directly via reticulate
Browse files Browse the repository at this point in the history
  • Loading branch information
FBerding committed Jul 9, 2023
1 parent 5ab2580 commit c5870df
Show file tree
Hide file tree
Showing 25 changed files with 59,268 additions and 79,599 deletions.
716 changes: 358 additions & 358 deletions .Rhistory

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
#)
conda_install(
packages = "tensorflow<2.11",
packages = "tensorflow",
envname = envname,
conda = "auto",
pip = TRUE
Expand Down
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ RoxygenNote: 7.2.3
Encoding: UTF-8
Depends:
quanteda,
keras,
foreach,
R (>= 2.10)
Suggests:
Expand Down
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import(Rcpp)
import(bundle)
import(doParallel)
import(foreach)
import(keras)
import(quanteda)
import(reshape2)
import(reticulate)
Expand Down
2 changes: 1 addition & 1 deletion R/install_and_config.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ install_py_modules<-function(envname="aifeducation"){
)

reticulate::conda_install(
packages = c("tensorflow<2.11"),
packages = c("tensorflow"),
envname = envname,
conda = "auto",
pip = TRUE
Expand Down
143 changes: 65 additions & 78 deletions R/te_classifier_neuralnet_model.R

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions R/transformer_bert.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ train_tune_bert_model=function(output_dir,
dir.create(paste0(output_dir,"/checkpoints"))
}
callback_checkpoint=tf$keras$callbacks$ModelCheckpoint(
filepath = paste0(output_dir,"/checkpoints/best_weights.cpkt"),
filepath = paste0(output_dir,"/checkpoints/best_weights.h5"),
monitor="val_loss",
verbose=1L,
mode="auto",
Expand All @@ -352,7 +352,7 @@ train_tune_bert_model=function(output_dir,
callbacks=list(callback_checkpoint))

cat(paste(date(),"Load Weights From Best Checkpoint"))
mlm_model$load_weights(paste0(output_dir,"/checkpoints/best_weights.cpkt"))
mlm_model$load_weights(paste0(output_dir,"/checkpoints/best_weights.h5"))

cat(paste(date(),"Saving Bert Model"))
mlm_model$save_pretrained(save_directory=output_dir)
Expand Down
4 changes: 2 additions & 2 deletions R/transformer_longformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ train_tune_longformer_model=function(output_dir,
}

callback_checkpoint=tf$keras$callbacks$ModelCheckpoint(
filepath = paste0(output_dir,"/checkpoints/best_weights.cpkt"),
filepath = paste0(output_dir,"/checkpoints/best_weights.h5"),
monitor="val_loss",
verbose=1L,
mode="auto",
Expand All @@ -299,7 +299,7 @@ train_tune_longformer_model=function(output_dir,
callbacks=list(callback_checkpoint))

cat(paste(date(),"Load Weights From Best Checkpoint"))
mlm_model$load_weights(paste0(output_dir,"/checkpoints/best_weights.cpkt"))
mlm_model$load_weights(paste0(output_dir,"/checkpoints/best_weights.h5"))

cat(paste(date(),"Saving Longformer Model"))
mlm_model$save_pretrained(save_directory=output_dir)
Expand Down
4 changes: 2 additions & 2 deletions R/transformer_roberta.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ train_tune_roberta_model=function(output_dir,
}

callback_checkpoint=tf$keras$callbacks$ModelCheckpoint(
filepath = paste0(output_dir,"/checkpoints/best_weights.cpkt"),
filepath = paste0(output_dir,"/checkpoints/best_weights.h5"),
monitor="val_loss",
verbose=1L,
mode="auto",
Expand All @@ -304,7 +304,7 @@ train_tune_roberta_model=function(output_dir,
callbacks=list(callback_checkpoint))

print(paste(date(),"Load Weights From Best Checkpoint"))
mlm_model$load_weights(paste0(output_dir,"/checkpoints/best_weights.cpkt"))
mlm_model$load_weights(paste0(output_dir,"/checkpoints/best_weights.h5"))

print(paste(date(),"Saving RoBERTa Model"))
mlm_model$save_pretrained(save_directory=output_dir)
Expand Down
1 change: 1 addition & 0 deletions docs/articles/classification_tasks.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/deps/bootstrap-5.2.2/bootstrap.min.css

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ articles:
aifeducation: aifeducation.html
classification_tasks: classification_tasks.html
sharing_and_publishing: sharing_and_publishing.html
last_built: 2023-07-07T07:43Z
last_built: 2023-07-09T19:11Z

6 changes: 0 additions & 6 deletions docs/reference/TextEmbeddingClassifierNeuralNet.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/search.json

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions man/TextEmbeddingClassifierNeuralNet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions tests/testthat/test-classifier_neural_net.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ test_that("training_baseline_only", {
batch_size=32,
dir_checkpoint=testthat::test_path("test_data/tmp"),
trace=FALSE,
view_metrics=FALSE,
keras_trace=0,
n_cores=1)
)
Expand Down Expand Up @@ -189,7 +188,6 @@ test_that("training_bsc_only", {
batch_size=32,
dir_checkpoint=testthat::test_path("test_data/tmp"),
trace=FALSE,
view_metrics=FALSE,
keras_trace=0,
n_cores=1)
)
Expand Down Expand Up @@ -222,7 +220,6 @@ test_that("training_pbl_baseline", {
batch_size=32,
dir_checkpoint=testthat::test_path("test_data/tmp"),
trace=FALSE,
view_metrics=FALSE,
keras_trace=0,
n_cores=1)
)
Expand Down Expand Up @@ -255,7 +252,6 @@ test_that("training_pbl_bsc", {
batch_size=32,
dir_checkpoint=testthat::test_path("test_data/tmp"),
trace=FALSE,
view_metrics=FALSE,
keras_trace=0,
n_cores=1)
)
Expand Down
5 changes: 2 additions & 3 deletions tests/testthat/test_data/bert/config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"_name_or_path": "test_data/bert",
"architectures": [
"BertForMaskedLM"
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
Expand All @@ -21,5 +20,5 @@
"transformers_version": "4.30.2",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 29729
"vocab_size": 30522
}
Loading

0 comments on commit c5870df

Please sign in to comment.