Skip to content
This repository has been archived by the owner on Oct 30, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1 from DIAGNijmegen/MatinHz-patch-1
Browse files Browse the repository at this point in the history
Update tfdata.shuffle buffer_size
  • Loading branch information
anindox8 committed Nov 4, 2021
2 parents 999ee75 + bf60fab commit d6eaf0e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tf2.5/scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
output_shapes = EXPECTED_IO_SHAPE) # Initialize TensorFlow Dataset
if str(args.CACHE_TDS_PATH)!='None':
train_gen = train_gen.cache(filename=(None if str(args.CACHE_TDS_PATH)=='None' else args.CACHE_TDS_PATH)) # Cache Dataset on Remote Server
train_gen = train_gen.shuffle(int(TRAIN_DATA_SAMPLES*0.50)) # Shuffle Samples
train_gen = train_gen.shuffle(int(TRAIN_DATA_SAMPLES)) # Shuffle Samples
train_gen = train_gen.map(lambda x,y: augment_tensors(x,y,args.AUGM_PARAMS,args.TRAIN_OBJ),
num_parallel_calls=multiprocessing.cpu_count())
train_gen = train_gen.batch(args.BATCH_SIZE) # Load Data in Batches
Expand Down

0 comments on commit d6eaf0e

Please sign in to comment.