Skip to content

Commit

Permalink
Revert "fix bugs and optimizations"
Browse files: browse the repository at this point in the history
This reverts commit 108be15.
  • Loading branch information
aria1th committed Oct 20, 2022
1 parent 108be15 commit f89829e
Showing 1 changed file with 46 additions and 59 deletions.
105 changes: 46 additions & 59 deletions modules/hypernetworks/hypernetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=Fa
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

self.linear = torch.nn.Sequential(*linears)

Expand Down Expand Up @@ -115,24 +115,11 @@ def weights(self):

for k, layers in self.layers.items():
for layer in layers:
layer.train()
res += layer.trainables()

return res

def eval(self):
    """Freeze every weight and put every layer into evaluation mode.

    Sets ``requires_grad = False`` on each tensor returned by
    ``self.weights()`` and calls ``eval()`` on each layer held in
    ``self.layers``. Returns ``None``.

    Order matters: ``self.weights()`` in this file calls ``layer.train()``
    on every layer while it collects the trainables. Calling it *after*
    ``layer.eval()`` (as the previous version did) silently flipped every
    layer back into training mode, so the mode switch has to come last.
    """
    # Collect and freeze the weights first; this call may reset layer modes.
    for weight in self.weights():
        weight.requires_grad = False

    # Switch modes last so nothing afterwards can undo it.
    for _, layers in self.layers.items():
        for layer in layers:
            layer.eval()

def train(self):
    """Put every layer into training mode and make every weight trainable.

    Calls ``train()`` on each layer held in ``self.layers``, then sets
    ``requires_grad = True`` on each tensor returned by ``self.weights()``.
    Returns ``None``.
    """
    for _, layer_group in self.layers.items():
        for module in layer_group:
            module.train()

    for param in self.weights():
        param.requires_grad = True

def save(self, filename):
state_dict = {}

Expand Down Expand Up @@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.sd_model.first_stage_model.to(devices.cpu)

hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True

losses = torch.zeros((32,))

last_saved_file = "<none>"
Expand All @@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename

scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
hypernetwork.train()
for i, entries in pbar:
hypernetwork.step = i + ititial_step

Expand All @@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

losses[hypernetwork.step % losses.shape[0]] = loss.item()

optimizer.zero_grad(set_to_none=True)
optimizer.zero_grad()
loss.backward()
del loss
optimizer.step()
mean_loss = losses.mean()
if torch.isnan(mean_loss):
Expand All @@ -356,47 +346,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})

if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
torch.cuda.empty_cache()
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
with torch.no_grad():
hypernetwork.eval()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)

p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)

if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20

preview_text = p.prompt

processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None

if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)

if image is not None:
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"

hypernetwork.train()
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)

p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)

if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20

preview_text = p.prompt

processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None

if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)

if image is not None:
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"

shared.state.job_no = hypernetwork.step

Expand Down

0 comments on commit f89829e

Please sign in to comment.