From ce8e3f1499a90c05730db9041633b24b1272d8c6 Mon Sep 17 00:00:00 2001
From: Ashafix
Date: Sat, 26 Jun 2021 20:57:41 +0200
Subject: [PATCH] fix for download, input format changed

---
 machine_translation/train_transformer_tf2.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/machine_translation/train_transformer_tf2.py b/machine_translation/train_transformer_tf2.py
index b53f371..5b3d9e5 100644
--- a/machine_translation/train_transformer_tf2.py
+++ b/machine_translation/train_transformer_tf2.py
@@ -28,7 +28,8 @@ def maybe_download_and_read_file(url, filename):
     """
     if not os.path.exists(filename):
         session = requests.Session()
-        response = session.get(url, stream=True)
+        response = session.get(url, stream=True,
+                               headers={'User-Agent': 'Chrome/91.0.4472.106'})
 
         CHUNK_SIZE = 32768
         with open(filename, "wb") as f:
@@ -37,7 +38,6 @@ def maybe_download_and_read_file(url, filename):
                     f.write(chunk)
 
     zipf = ZipFile(filename)
-    filename = zipf.namelist()
     with zipf.open('fra.txt') as f:
         lines = f.read()
 
@@ -73,7 +73,7 @@ def normalize_string(s):
     return s
 
 
-raw_data_en, raw_data_fr = list(zip(*raw_data))
+raw_data_en, raw_data_fr, _ = list(zip(*raw_data))
 raw_data_en = [normalize_string(data) for data in raw_data_en]
 raw_data_fr_in = [' ' + normalize_string(data) for data in raw_data_fr]
 raw_data_fr_out = [normalize_string(data) + ' ' for data in raw_data_fr]