Support zeroing out text embeddings with the attention mask.
comfyanonymous committed Jun 9, 2024
1 parent 6cd8ffc commit 742d572
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions comfy/sd1_clip.py

@@ -68,7 +68,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True):  # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
+                 return_projected_pooled=True):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -90,6 +91,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
 
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = enable_attention_masks
+        self.zero_out_masked = zero_out_masked
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
@@ -179,9 +181,12 @@ def forward(self, tokens):
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
-            z = outputs[0]
+            z = outputs[0].float()
         else:
-            z = outputs[1]
+            z = outputs[1].float()
+
+        if self.zero_out_masked and attention_mask is not None:
+            z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
         if len(outputs) >= 3:
@@ -190,7 +195,7 @@ def forward(self, tokens):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
-        return z.float(), pooled_output
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
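The new zero_out_masked path reduces to a broadcasted multiply: the (batch, seq_len) attention mask is unsqueezed to (batch, seq_len, 1) so every feature of a padded token position is zeroed. A minimal, self-contained sketch of that operation on toy tensors (shapes and values are illustrative, not taken from the repository):

```python
import torch

# Toy stand-ins for the encoder output and attention mask (illustrative
# shapes, not from the repository): z is (batch, seq_len, hidden);
# the mask is (batch, seq_len) with 1 for real tokens and 0 for padding.
z = torch.randn(1, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# The same operation the commit adds to forward(): broadcast the mask
# over the hidden dimension so padded positions become all-zero vectors.
z *= attention_mask.unsqueeze(-1).float()

print(z[0, 3])  # tensor([0., 0., 0., 0.]) at the first padded position
```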
