Skip to content

Commit

Permalink
Merge pull request #10 from FizzleDorf/ait_fix
Browse files Browse the repository at this point in the history
Ait fix
  • Loading branch information
FizzleDorf authored Nov 5, 2023
2 parents e0350f6 + 0a18145 commit 313cf02
Showing 1 changed file with 15 additions and 71 deletions.
86 changes: 15 additions & 71 deletions ait_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def set_weights(self, sd):
constants = map_unet_params(sd)
self.exe_module.set_many_constants_with_tensors(constants)

def apply_model(self, xc, t, context, y=None, control=None, transformer_options=None):
def apply_model(self, xc, t, c_crossattn, y=None, control=None, transformer_options=None, **kwargs):
xc = xc.permute((0, 2, 3, 1)).half().contiguous()
output = [torch.empty_like(xc)]
inputs = {"x": xc, "timesteps": t.half(), "context": context.half()}
inputs = {"x": xc, "timesteps": t.half(), "context": c_crossattn.half()}
if y is not None:
inputs['y'] = y.half()
self.exe_module.run_with_tensors(inputs, output, graph_mode=False)
Expand All @@ -69,28 +69,23 @@ def __call__(self, model_function, params):
self.ait_model.set_weights(sd)

c_concat = params["c"].get("c_concat", None)

# The BaseModel instance
inner_model = self.model.model
x = params["input"]
sigma = params["timestep"]
xc = inner_model.model_sampling.calculate_input(sigma, x)
t = inner_model.model_sampling.timestep(sigma).float()
t = self.model.model.model_sampling.timestep(sigma).float()
x = self.model.model.model_sampling.calculate_input(sigma, params["input"])
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
context = params["c"].get("c_crossattn")
y = params["c"].get("y")
control = params["c"].get("control")
transformer_options = params["c"].get("transformer_options")
out = self.ait_model.apply_model(xc, t, context, y, control, transformer_options)
return inner_model.model_sampling.calculate_denoised(sigma, out, x)
xc = torch.cat([x] + [c_concat], dim=1)
else:
xc = x
model_output = self.ait_model.apply_model(x, t, **params["c"]).float()
return self.model.model.model_sampling.calculate_denoised(sigma, model_output, params["input"])

def to(self, a):
if self.ait_model is not None:
if a == torch.device("cpu"):
self.ait_model.inner_model.unload_module()
self.ait_model.unload_model()
self.ait_model = None
print("unloaded AIT")
return self

class AIT_Unet_Loader:
@classmethod
Expand All @@ -108,62 +103,11 @@ def load_ait(self, model, ait_name):
patch = AITPatch(model, ait_path)
model_ait = model.clone()
model_ait.set_model_unet_function_wrapper(patch)
print(patch)
return (model_ait,)

class AIT_VAE_Encode_Loader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ),
"vae": ("VAE",),
"ait_name": (folder_paths.get_filename_list("ait"), ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "load_ait"

CATEGORY = "loaders/AIT"

@staticmethod
def vae_encode_crop_pixels(self, pixels):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels

def load_ait(self, pixels, ait_name, vae):
resolution = max(pixels.shape[1], pixels.shape[2])
model_type = "vae_encode"

# Clear any previously loaded VAE models
if len(AITemplate.vae.keys()) > 0:
to_delete = list(AITemplate.vae.keys())
for key in to_delete:
del AITemplate.vae[key]

# Load the VAE module using the provided "ait_name"
module_filename = folder_paths.get_full_path("ait", ait_name)
if module_filename not in AITemplate.vae:
AITemplate.vae[module_filename] = AITemplate.loader.load_module(module_filename)

AITemplate.vae[module_filename] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.vae[module_filename],
vae=AITemplate.loader.compvis_vae(vae.first_stage_model.state_dict()),
encoder=True,
)

# Perform any required image processing here
pixels = self.vae_encode_crop_pixels(pixels)
pixels = pixels[:, :, :, :3]
pixels = pixels.movedim(-1, 1)
pixels = 2. * pixels - 1.

samples = vae_inference(AITemplate.vae[module_filename], pixels, encoder=True)
samples = samples.cpu()

# Unload the module after inference
del AITemplate.vae[module_filename]
torch.cuda.empty_cache()

return ({"samples": samples},)
NODE_CLASS_MAPPINGS = {
"AIT_Unet_Loader": AIT_Unet_Loader,
}

0 comments on commit 313cf02

Please sign in to comment.