Optimize input image according to text prompt using guidance #287

Open · wants to merge 6 commits into base: main
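This PR adds an `img_opt` mode that optimizes an input image directly against a text prompt using score-distillation gradients from the guidance model, with no NeRF in the loop: an `img_opt()` method on both the DeepFloyd IF and Stable Diffusion guidance classes (raw pixels for IF, latents for SD), an `ImageOpt` trainer in `nerf/utils.py`, an `--img_opt` flag in `main.py`, and a demo script `scripts/run_img_opt.sh`. It also threads an `fp16` flag through the IF guidance class.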
50 changes: 48 additions & 2 deletions guidance/if_utils.py
@@ -34,7 +34,7 @@ def seed_everything(seed):


class IF(nn.Module):
- def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
+ def __init__(self, device, vram_O, t_range=[0.02, 0.98], fp16=True):
super().__init__()

self.device = device
@@ -45,8 +45,10 @@ def __init__(self, device, vram_O, t_range=[0.02, 0.98]):

is_torch2 = torch.__version__[0] == '2'

self.precision_t = torch.float16 if fp16 else torch.float32

# Create model
- pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
+ pipe = IFPipeline.from_pretrained(model_key, variant="fp16" if fp16 else None, torch_dtype=self.precision_t)  # the hub hosts no "fp32" variant; variant=None loads the default full-precision weights
if not is_torch2:
pipe.enable_xformers_memory_efficient_attention()

@@ -175,6 +177,50 @@ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num
return imgs


def img_opt(self, text_embeddings, images, guidance_scale=40.0, sorted=True):

if sorted:
timesteps = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device).sort()[0]
else:
timesteps = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

with torch.no_grad():

# add noise to latents using the timesteps
noise = torch.randn_like(images)
images_noisy = self.scheduler.add_noise(images, noise, timesteps).to(self.device)

# predict the noise residual
model_input = torch.cat([images_noisy] * 2)
model_input = self.scheduler.scale_model_input(model_input, timesteps)
tt = torch.cat([timesteps] * 2)
text_input = text_embeddings.repeat_interleave(len(images_noisy), 0)
noise_pred = []
# run the UNet one sample at a time (effective batch_size=1) to bound peak memory
for s, t, text in zip(model_input, tt, text_input):
noise_pred.append(self.unet(sample=s[None, ...],
timestep=t[None, ...],
encoder_hidden_states=text[None, ...]).sample)

noise_pred = torch.cat(noise_pred)

# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# w(t), sigma_t^2
w = (1 - self.alphas[timesteps])
grad = w[:, None, None, None] * (noise_pred - noise)
grad = torch.nan_to_num(grad)

# grad omits the UNet Jacobian term (the SDS approximation), so inject it via the custom autograd function
loss = SpecifyGradient.apply(images, grad)

return images, loss


if __name__ == '__main__':

import argparse
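Reviewer note: both new `img_opt` methods route their hand-computed gradient through `SpecifyGradient`, which this repo already defines in `guidance/sd_utils.py`. It returns a dummy scalar loss in the forward pass and hands the stored gradient back in the backward pass; a minimal sketch of the idea (the in-repo definition may differ in detail):

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class SpecifyGradient(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # dummy scalar "loss": its value is meaningless, only its gradient matters
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (gt_grad,) = ctx.saved_tensors
        # hand back the precomputed gradient, scaled by whatever flowed into the dummy loss
        return gt_grad * grad_scale, None
```

So `loss = SpecifyGradient.apply(images, grad)` makes `loss.backward()` deposit exactly `grad` (times any AMP loss scaling) into `images.grad`.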
44 changes: 42 additions & 2 deletions guidance/sd_utils.py
@@ -153,11 +153,11 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=Fa
# see zero123_utils.py's version for a simpler implementation.
alphas = self.scheduler.alphas.to(latents)
total_timesteps = self.max_step - self.min_step + 1
index = total_timesteps - t.to(latents.device) - 1
b = len(noise_pred)
a_t = alphas[index].reshape(b,1,1,1).to(self.device)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)
pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

@@ -242,6 +242,46 @@ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num

return imgs

def img_opt(self, text_embeddings, latents, guidance_scale=40.0, sorted=True):

with torch.no_grad():

if sorted:
timesteps = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device).sort()[0]
else:
timesteps = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn_like(latents)
noisy_latents = self.scheduler.add_noise(latents, noise, timesteps).to(self.device)

# predict the noise residual
samples_unet = torch.cat([noisy_latents] * 2)
timesteps_unet = torch.cat([timesteps] * 2)
text_unet = text_embeddings.repeat_interleave(len(noisy_latents), 0)
noise_pred = []
# run the UNet one sample at a time (effective batch_size=1) to bound peak memory
for s, t, text in zip(samples_unet, timesteps_unet, text_unet):
noise_pred.append(self.unet(sample=s[None, ...],
timestep=t[None, ...],
encoder_hidden_states=text[None, ...]).sample)

noise_pred = torch.cat(noise_pred)

# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# w(t), sigma_t^2
w = (1 - self.alphas[timesteps])
grad = w[:, None, None, None] * (noise_pred - noise)
grad = torch.nan_to_num(grad)

# grad omits the UNet Jacobian term (the SDS approximation), so inject it via the custom autograd function
loss = SpecifyGradient.apply(latents, grad)

return latents, loss


if __name__ == '__main__':

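For reference, the `w(t), sigma_t^2` comment in both `img_opt` methods is the score distillation sampling (SDS) weighting from DreamFusion. Because the optimized variable here is the image (or latent) itself, the usual Jacobian factor is the identity; assuming `self.alphas` holds the cumulative products $\bar{\alpha}_t$ (as elsewhere in this repo), the gradient implemented above is

$$
\nabla_{x}\,\mathcal{L}_{\mathrm{SDS}} \;=\; \mathbb{E}_{t,\epsilon}\!\left[\, w(t)\,\bigl(\hat{\epsilon}_\phi(x_t;\,y,\,t) - \epsilon\bigr) \right],
\qquad w(t) \;=\; 1-\bar{\alpha}_t .
$$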
32 changes: 27 additions & 5 deletions main.py
@@ -159,6 +159,9 @@ def __call__ (self, parser, namespace, values, option_string = None):
parser.add_argument('--exp_start_iter', type=int, default=None, help="start iter # for experiment, to calculate progressive_view and progressive_level")
parser.add_argument('--exp_end_iter', type=int, default=None, help="end iter # for experiment, to calculate progressive_view and progressive_level")

# optimize an input image according to the text prompt
parser.add_argument('--img_opt', action='store_true', help="optimize the input image according to the text prompt")

opt = parser.parse_args()

if opt.O:
@@ -196,7 +199,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
else:
# use stable-diffusion when providing both text and image
opt.guidance = ['SD', 'clip']

if not opt.dont_override_stuff:
opt.guidance_scale = 10
opt.t_range = [0.2, 0.6]
@@ -212,7 +215,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
opt.latent_iter_ratio = 0
if not opt.dont_override_stuff:
opt.albedo_iter_ratio = 0

# make shape init more stable
opt.progressive_view = True
opt.progressive_level = True
@@ -249,7 +252,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
opt.w = int(opt.w * opt.dmtet_reso_scale)
opt.known_view_scale = 1

if not opt.dont_override_stuff:
opt.t_range = [0.02, 0.50] # ref: magic3D

if opt.images is not None:
@@ -271,7 +274,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
if not opt.dont_override_stuff:
# disable as they disturb progressive view
opt.jitter_pose = False

opt.uniform_sphere_rate = 0
# back up full range
opt.full_radius_range = opt.radius_range
@@ -334,6 +337,25 @@ def __call__ (self, parser, namespace, values, option_string = None):
if opt.save_mesh:
trainer.save_mesh()

elif opt.img_opt:

if 'SD' in opt.guidance:
from guidance.sd_utils import StableDiffusion
guide = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range)
name = f"SD"

elif 'IF' in opt.guidance:
from guidance.if_utils import IF
guide = IF(device, opt.vram_O, opt.t_range, opt.fp16)
name = f"DeepFloydIF"

seed = opt.seed or 0

name += f"_fp{16 if opt.fp16 else 32}_iters{opt.iters}_lr{opt.lr}_seed{seed}_{opt.text.replace(' ', '_')}"

img_opt = ImageOpt(opt, guide, name, seed)
img_opt.train()

elif opt.test:
guidance = None # no need to load guidance model at test

@@ -376,7 +398,7 @@ def __call__ (self, parser, namespace, values, option_string = None):

if 'IF' in opt.guidance:
from guidance.if_utils import IF
- guidance['IF'] = IF(device, opt.vram_O, opt.t_range)
+ guidance['IF'] = IF(device, opt.vram_O, opt.t_range, fp16=opt.fp16)

if 'zero123' in opt.guidance:
from guidance.zero123_utils import Zero123
103 changes: 103 additions & 0 deletions nerf/utils.py
@@ -1257,3 +1257,106 @@ def get_GPU_mem():
mems.append(int(((mem_total - mem_free)/1024**3)*1000)/1000)
mem += mems[-1]
return mem, mems


class ImageOpt():

def __init__(self, opt, guide, name="SD_imgopt", seed=0):

self.opt = opt
self.guide = guide
self.name = name

for p in guide.parameters():
p.requires_grad = False

torch.manual_seed(seed)
if opt.IF:
# DeepFloyd IF optimizes raw pixels: 3-channel 64x64 images in [-1, 1]
images = torch.rand((opt.batch_size, 3, 64, 64), device=guide.device, dtype=guide.precision_t) * 2 - 1
else:
# Stable Diffusion optimizes 4-channel 64x64 latents instead of pixels
images = torch.randn((opt.batch_size, 4, 64, 64), device=guide.device, dtype=guide.precision_t)

self.images = images.requires_grad_(True)

self.optimizer = optim.Adam([self.images], lr=opt.lr) # naive adam

self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler

self.scaler = torch.cuda.amp.GradScaler(enabled=self.opt.fp16)

self.sample_freq = max(1, self.opt.iters//20)

uncond_embeddings = guide.get_text_embeds([""])
text_embeddings = guide.get_text_embeds([self.opt.text])
self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

def train(self):

self.local_step = 0
samples = []

for i in tqdm.tqdm(range(self.opt.iters)):

self.local_step += 1

self.optimizer.zero_grad()

with torch.cuda.amp.autocast(enabled=self.opt.fp16):
self.images, loss = self.guide.img_opt(self.text_embeddings, self.images)

# hooked grad clipping for RGB space
if self.opt.grad_clip_rgb >= 0:
def _hook(grad):
if self.opt.fp16:
# correctly handle the scale
grad_scale = self.scaler._get_scale_async()
return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb)
else:
return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb)
self.images.register_hook(_hook)
# self.images.retain_grad()

self.scaler.scale(loss).backward()

self.post_train_step()
self.scaler.step(self.optimizer)
self.scaler.update()
self.lr_scheduler.step()

if i % self.sample_freq == 0:
if self.opt.IF:
samples.append(self.images.detach().cpu().permute(0, 2, 3, 1).numpy())
else:
with torch.no_grad():
images = []
for latent in self.images:
images.append(self.guide.decode_latents(latent[None, ...].detach()))
image = torch.cat(images).cpu().permute(0, 2, 3, 1).numpy()
samples.append(image)

gif = self.make_gif_from_imgs(samples, resize=(4 if "SD" in self.name else 1))
os.makedirs(self.opt.workspace, exist_ok=True)
imageio.mimwrite(os.path.join(self.opt.workspace, f'{self.name}_imgopt.mp4'), gif, fps=10, quality=8, macro_block_size=1)

def make_gif_from_imgs(self, frames, resize=1.0, upto=None, repeat_first=2, repeat_last=5, skip=1,
f=0, s=0.75, t=2):  # f, s, t: cv2.putText fontFace, fontScale, thickness
imgs = []
from PIL import Image
for i, img in tqdm.tqdm(enumerate(frames[:upto:skip]), total=len(frames[:upto:skip])):
img = np.moveaxis(img, 0, 1).reshape(img.shape[1], -1, 3)
img = np.array(Image.fromarray((img*255).astype(np.uint8)).resize((int(img.shape[1]/resize), int(img.shape[0]/resize)), Image.Resampling.LANCZOS))
text = f"{i*self.sample_freq:05d}"
img = cv2.putText(img=img, text=text, org=(0, 20), fontFace=f, fontScale=s, color=(0,0,0), thickness=t)
imgs.append(img)
# pad the sequence with repeated first/last frames (writing the video happens in train())
return [imgs[0]]*repeat_first + imgs + [imgs[-1]]*repeat_last

def post_train_step(self):

# unscale grad before modifying it!
# ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
self.scaler.unscale_(self.optimizer)

# clip grad
if self.opt.grad_clip >= 0:
torch.nn.utils.clip_grad_value_(self.images, self.opt.grad_clip)
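Side note on the AMP bookkeeping above: the RGB-space `_hook` clamps gradients while they are still scaled (hence the multiplication by `scaler._get_scale_async()`), whereas `post_train_step` first unscales and then clips, following the recipe in the linked PyTorch docs. A minimal self-contained sketch of that unscale-then-clip pattern (the model and values here are illustrative only):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)

x = torch.randn(4, 8, device='cuda')
with torch.cuda.amp.autocast(enabled=True):
    loss = model(x).square().mean()

scaler.scale(loss).backward()    # gradients are scaled here
scaler.unscale_(optimizer)       # unscale before touching the gradients
torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)
scaler.step(optimizer)           # skips the step if infs/NaNs are found
scaler.update()
```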
8 changes: 8 additions & 0 deletions scripts/run_img_opt.sh
@@ -0,0 +1,8 @@
# SD fp32
CUDA_VISIBLE_DEVICES=0 python main.py --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
# DeepFloydIF fp32
CUDA_VISIBLE_DEVICES=0 python main.py --IF --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
# # SD fp16
# CUDA_VISIBLE_DEVICES=0 python main.py -O --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
# # DeepFloydIF fp16
# CUDA_VISIBLE_DEVICES=0 python main.py --IF -O --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
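Each run writes its sampled frames as a progress video to `<workspace>/<name>_imgopt.mp4`, capturing a frame every `iters // 20` steps.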