Skip to content

Commit

Permalink
set python tests to transformers 4.35 (#1551)
Browse files Browse the repository at this point in the history
* python tests transformers 4.35
  • Loading branch information
vince62s committed Nov 20, 2023
1 parent e52b295 commit 46f57e2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
transformers==4.29.*;platform_system=='Linux'
transformers==4.35.*;platform_system=='Linux'
fairseq==0.12.2;platform_system=='Linux' or platform_system=='Darwin'
OpenNMT-py==2.2.*;platform_system=='Linux' or platform_system=='Darwin'
OpenNMT-tf==2.30.*
Expand Down
8 changes: 6 additions & 2 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,13 +984,17 @@ def test_transformers_wav2vec2(
w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
del w2v2_model.wav2vec2.encoder.layers
del w2v2_model.wav2vec2.encoder.layer_norm
torch.save(w2v2_model, output_dir + "/wav2vec2_partial.bin")
w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")

device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
w2v2_model = torch.load(output_dir + "/wav2vec2_partial.bin").to(device)
w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
output_dir + "/wav2vec2_partial.bin"
).to(device)
del w2v2_model.wav2vec2.encoder.layers
del w2v2_model.wav2vec2.encoder.layer_norm
w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
output_dir,
Expand Down

0 comments on commit 46f57e2

Please sign in to comment.