Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added batch image generation and extended logging. #35

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def parse_args():
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--img_sz", type=int, default=512)
parser.add_argument("--img_resz", type=int, default=256)
parser.add_argument("--batch_size", type=int, default=16)

args = parser.parse_args()
return args
Expand Down Expand Up @@ -52,23 +53,27 @@ def parse_args():
file_list = get_file_list_from_csv(args.data_list)
params_str = pipeline.get_sdm_params()

t0 = time.time()
for i, file_info in enumerate(file_list):
img_name = file_info[0]
val_prompt = file_info[1]
t0 = time.perf_counter()
for batch_start in range(0, len(file_list), args.batch_size):
batch_end = batch_start + args.batch_size
img_names = [file_info[0] for file_info in file_list[batch_start: batch_end]]
val_prompts = [file_info[1] for file_info in file_list[batch_start: batch_end]]

print("---")
print(f"{i}/{len(file_list)} | {img_name} {val_prompt} | {args.num_inference_steps} steps")
print(params_str)
for i, (img_name, val_prompt) in enumerate(zip(img_names, val_prompts)):
print("---")
print(f"{batch_start + i}/{len(file_list)} | {img_name} {val_prompt} | {args.num_inference_steps} steps")
print(params_str)

img = pipeline.generate(prompt = val_prompt,
n_steps = args.num_inference_steps,
img_sz = args.img_sz)
img.save(os.path.join(save_dir_im512, img_name))
img.close()
imgs = pipeline.generate(prompt = val_prompts,
n_steps = args.num_inference_steps,
img_sz = args.img_sz)

for img, img_name in zip(imgs, img_names):
img.save(os.path.join(save_dir_im512, img_name))
img.close()

pipeline.clear()

change_img_size(save_dir_im512, save_dir_im256, args.img_resz)
print(f"{time.time()-t0} sec elapsed")
print(f"{(time.perf_counter()-t0):.2f} sec elapsed")

9 changes: 5 additions & 4 deletions src/generate_single_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def parse_args():
parser.add_argument("--use_dpm_solver", action='store_true', help='use DPMSolverMultistepScheduler')
parser.add_argument("--is_lora_checkpoint", action='store_true', help='specify whether to use LoRA finetuning')
parser.add_argument("--lora_weight_path", type=str, default=None, help='dir path including lora.pt and lora_config.json')

parser.add_argument("--batch_size", type=int, default=16)

args = parser.parse_args()
return args

Expand All @@ -48,16 +49,16 @@ def parse_args():
save_path = os.path.join(args.save_dir, args.val_prompt)
os.makedirs(save_path, exist_ok=True)

t0 = time.time()
t0 = time.perf_counter()
for i in range(args.num_images):
print(f"Generate {args.val_prompt} --- {i} | {args.num_inference_steps} steps")
img = pipeline.generate(prompt = args.val_prompt,
n_steps = args.num_inference_steps,
img_sz = args.img_sz)
img_sz = args.img_sz)[0]
img.save(os.path.join(save_path, f"{i}.png"))
img.close()

pipeline.clear()
print(f"Save to {save_path}")
print(f"{time.time()-t0} sec elapsed")
print(f"{(time.perf_counter()-t0):.2f} sec elapsed")

21 changes: 20 additions & 1 deletion src/kd_train_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
import time
import copy

# Optional dependency: wandb is only needed when the "wandb" tracker is
# enabled, so degrade gracefully when it is not installed.
try:
    import wandb
    has_wandb = True
# Catch only ImportError: a bare `except:` would also swallow unrelated
# failures such as KeyboardInterrupt or SystemExit raised during import.
except ImportError:
    has_wandb = False

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

Expand Down Expand Up @@ -714,6 +721,9 @@ def collate_fn(examples):
add_hook(unet_teacher, acts_tea, mapping_layers_tea)
add_hook(unet, acts_stu, mapping_layers_stu)

# get wandb_tracker (if it exists)
wandb_tracker = accelerator.get_tracker("wandb")

for epoch in range(first_epoch, args.num_train_epochs):

unet.train()
Expand Down Expand Up @@ -808,7 +818,16 @@ def collate_fn(examples):
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
accelerator.log(
{
"train_loss": train_loss,
"loss_sd": loss_sd,
"loss_kd_output": loss_kd_output,
"loss_kd_feat": loss_kd_feat,
"lr": lr_scheduler.get_last_lr()[0]
},
step=global_step
)

if accelerator.is_main_process:
with open(csv_log_path, 'a') as logfile:
Expand Down
6 changes: 4 additions & 2 deletions src/utils/inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# Copyright (c) 2023 Nota Inc. All Rights Reserved.
# Code modified from https://huggingface.co/blog/stable_diffusion
# ------------------------------------------------------------------------------------
from typing import Union, List

import diffusers
from diffusers import StableDiffusionPipeline
import torch
import gc
import json
from PIL import Image
from peft import LoraModel, LoraConfig, set_peft_model_state_dict

diffusers_version = int(diffusers.__version__.split('.')[1])
Expand Down Expand Up @@ -41,15 +43,15 @@ def set_pipe_and_generator(self):

self.generator = torch.Generator(device=self.device).manual_seed(self.seed)

def generate(self, prompt: Union[str, List[str]], n_steps: int, img_sz: int) -> List[Image.Image]:
    """Run the diffusion pipeline and return the generated image(s).

    Args:
        prompt: a single text prompt or a batch (list) of prompts.
        n_steps: number of denoising inference steps.
        img_sz: output height and width in pixels (square images).

    Returns:
        A list of PIL images, one per prompt. NOTE: a single ``str``
        prompt also yields a one-element list, so callers that want a
        single image must index ``[0]``.
    """
    out = self.pipe(
        prompt,
        num_inference_steps=n_steps,
        height=img_sz,
        width=img_sz,
        # reuse the seeded generator so repeated calls are reproducible
        generator=self.generator,
    )
    return out.images

def _count_params(self, model):
return sum(p.numel() for p in model.parameters())
Expand Down