Skip to content

Commit

Permalink
Updated to use f-string formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
azograby committed May 7, 2024
1 parent 1b81339 commit 5b44dcd
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions use-cases/text-to-image-fine-tuning/code/train
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ def train():
try:
input_files = [ os.path.join(training_path, file) for file in os.listdir(training_path) ]
if len(input_files) == 0:
raise ValueError(('There are no files in {}.\n' +
'This usually indicates that the channel ({}) was incorrectly specified,\n' +
'the data specification in S3 was incorrectly specified or the role specified\n' +
'does not have permission to access the data.').format(training_path, channel_name))
error_message = (
f"There are no files in {training_path}.\n"
f"This usually indicates that the channel ({channel_name}) was incorrectly specified,\n"
"the data specification in S3 was incorrectly specified or the role specified\n"
"does not have permission to access the data."
)
raise ValueError(error_message)

# Call the program that was installed in the training container, which uses the kohya-ss libraries to train
# Stable Diffusion XL given the kohya-sdxl-config.toml file that is present in the training S3 bucket
Expand All @@ -44,11 +47,14 @@ def train():
print('Training complete.')
except Exception as e:
# Write out an error file. This will be returned as the failureReason in the DescribeTrainingJob result
trc = traceback.format_exc()
with open(os.path.join(output_path, 'failure'), 'w') as s:
s.write('Exception during training: ' + str(e) + '\n' + trc)
traceback_str = traceback.format_exc()
failure_log_path = os.path.join(output_path, 'failure')

with open(failure_log_path, 'w') as log_file:
log_file.write(f'Exception during training: {str(e)}\n{traceback_str}')
# Printing this to stderr makes the exception appear in the training job logs as well.
print('Exception during training: ' + str(e) + '\n' + trc, file=sys.stderr)
print(f'Exception during training: {str(e)}\n{traceback_str}', file=sys.stderr)

# A non-zero exit code causes the training job to be marked as Failed.
sys.exit(255)

Expand Down

0 comments on commit 5b44dcd

Please sign in to comment.