Skip to content

Commit

Permalink
prevent accidental creation of CLIP models in float32 type when user wants float16
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jun 16, 2024
1 parent 7ee2114 commit b443fdc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
6 changes: 3 additions & 3 deletions modules/models/sd3/sd3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def __init__(self, *args, **kwargs):
self.tokenizer = SD3Tokenizer()

with torch.no_grad():
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)

self.weights_loaded = False

Expand Down
1 change: 1 addition & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def set_model_fields(model):
if not hasattr(model, 'latent_channels'):
model.latent_channels = 4


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
Expand Down

0 comments on commit b443fdc

Please sign in to comment.