
Commit

add adafactor optimizer
saippuakauppias committed Sep 24, 2019
1 parent e6afb28 commit a1bc418
Showing 2 changed files with 15 additions and 10 deletions.
22 changes: 13 additions & 9 deletions gpt_2_simple/gpt_2.py
@@ -20,6 +20,8 @@
except:
pass

+from tensor2tensor.utils.adafactor import AdafactorOptimizer

from gpt_2_simple.src import model, sample, encoder, memory_saving_gradients
from gpt_2_simple.src.load_dataset import load_dataset, Sampler
from gpt_2_simple.src.accumulate import AccumulatingOptimizer
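Note: unlike the guarded try/except import that closes just above this hunk, the new tensor2tensor import is unconditional, so tensor2tensor becomes a hard dependency of the package (see the requirements.txt change below). A minimal sketch of how the import could instead be kept optional in the same try/except style (purely illustrative, not part of this commit):

try:
    from tensor2tensor.utils.adafactor import AdafactorOptimizer
except ImportError:
    # adafactor support stays disabled until tensor2tensor is installed
    AdafactorOptimizer = None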
@@ -36,10 +38,10 @@ def download_file_with_progress(url_base, sub_dir, model_name, file_name):
file_name : str
name of file to get e.g. "hparams.json"
sub_dir: str
-subdirectory inside which to get and copy locally eg. "models/124M"
+subdirectory inside which to get and copy locally eg. "models/124M"
no trailing slash
url_base : str
-Start of URL location specifying server and any base directories no
+Start of URL location specifying server and any base directories no
trailing slash
e.g. "https://storage.googleapis.com/gpt-2"
"""
@@ -54,7 +56,7 @@ def download_file_with_progress(url_base, sub_dir, model_name, file_name):
for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
f.write(chunk)
pbar.update(DOWNLOAD_CHUNK_SIZE)


def download_gpt2(model_dir='models', model_name='124M'):
"""Downloads the GPT-2 model into the current directory
@@ -66,8 +68,8 @@ def download_gpt2(model_dir='models', model_name='124M'):
parent directory of model to download
model_name : str
-name of the GPT-2 model to download.
-As of 22 May 2019 one of "124M" or "355M" but may later include other
+name of the GPT-2 model to download.
+As of 22 May 2019 one of "124M" or "355M" but may later include other
model sizes
Adapted from https://github.com/openai/gpt-2/blob/master/download_model.py
@@ -101,7 +103,7 @@ def start_tf_sess(threads=-1, server=None):

if server is not None:
return tf.compat.v1.Session(target=server.target, config=config)

return tf.compat.v1.Session(config=config)


@@ -201,6 +203,8 @@ def maketree(path):
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
elif optimizer == 'sgd':
opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
+elif optimizer == 'adafactor':
+    opt = AdafactorOptimizer(learning_rate=learning_rate)

if accumulate_gradients > 1:
if use_memory_saving_gradients:
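With the new branch in place, adafactor is selected by name exactly like adam and sgd. A minimal usage sketch (the dataset path and step count are placeholders, and it assumes finetune() exposes the same optimizer choice that the CLI flag further down passes through):

import gpt_2_simple as gpt2

gpt2.download_gpt2(model_name="124M")
sess = gpt2.start_tf_sess()

# 'adafactor' now builds tensor2tensor's AdafactorOptimizer;
# 'adam' and 'sgd' keep their previous behaviour.
gpt2.finetune(sess,
              dataset="corpus.txt",   # placeholder training file
              model_name="124M",
              steps=1000,
              optimizer="adafactor")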
@@ -305,7 +309,7 @@ def sample_batch():

if steps:
steps = int(steps)

try:
while True:
if steps > 0 and counter == (counter_base + steps):
@@ -645,7 +649,7 @@ def cmd():
)

# Explicit arguments

parser.add_argument(
'--mode', help='Mode for using the CLI (either "finetune" or "generate") [Required]', nargs='?')
parser.add_argument(
@@ -679,7 +683,7 @@ def cmd():
'--print_every', help="[finetune] After how many steps to print progress",
nargs='?', default=10, type=int)
parser.add_argument(
-'--optimizer', help="[finetune] Optimizer to use for finetuning (adam or sgd)",
+'--optimizer', help="[finetune] Optimizer to use for finetuning (adam or sgd or adafactor)",
nargs='?', default='adam')
parser.add_argument(
'--overwrite', help="[finetune] Overwrite existing model when continuing training",
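From the command line, the same choice goes through the existing --optimizer flag updated above, e.g. gpt_2_simple --mode finetune --optimizer adafactor ... (assuming the console entry point that wraps cmd(); the remaining finetune flags are elided).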
3 changes: 2 additions & 1 deletion requirements.txt
@@ -2,4 +2,5 @@ regex
requests
tqdm
numpy
-toposort
+toposort
+tensor2tensor
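Note that only requirements.txt is updated; setup.py is not among the two changed files, so an existing pip install of the package will not pull in tensor2tensor automatically. It has to be added by hand (e.g. pip install tensor2tensor) before the new unconditional import in gpt_2.py can succeed.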
