Skip to content

Commit

Permalink
fixed typo in example train_text_to_image.py (huggingface#3608)
Browse files Browse the repository at this point in the history
fixed typo
  • Loading branch information
kashif authored and Jimmy committed Apr 26, 2024
1 parent 9110dd8 commit db4719d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/text_to_image/train_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
"--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
)
parser.add_argument(
"--pretrained_model_name_or_path",
Expand Down Expand Up @@ -830,16 +830,16 @@ def collate_fn(examples):
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)
if args.input_pertubation:
new_noise = noise + args.input_pertubation * torch.randn_like(noise)
if args.input_perturbation:
new_noise = noise + args.input_perturbation * torch.randn_like(noise)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.input_pertubation:
if args.input_perturbation:
noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
Expand Down

0 comments on commit db4719d

Please sign in to comment.