Skip to content

Commit

Permalink
Inference is twice faster by calling CLIP just once and caching the r…
Browse files Browse the repository at this point in the history
…esults
  • Loading branch information
GuyTevet committed Apr 12, 2024
1 parent 63edacf commit 94c173f
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/

save/
4 changes: 4 additions & 0 deletions diffusion/gaussian_diffusion.py
Expand Up @@ -637,6 +637,10 @@ def p_sample_loop(
if dump_steps is not None:
dump = []

if 'text' in model_kwargs['y'].keys():
# encoding once instead of each iteration saves lots of time
model_kwargs['y']['text_embed'] = model.encode_text(model_kwargs['y']['text'])

for i, sample in enumerate(self.p_sample_loop_progressive(
model,
shape,
Expand Down
1 change: 1 addition & 0 deletions model/cfg_sampler.py
Expand Up @@ -20,6 +20,7 @@ def __init__(self, model):
self.nfeats = self.model.nfeats
self.data_rep = self.model.data_rep
self.cond_mode = self.model.cond_mode
self.encode_text = self.model.encode_text

def forward(self, x, timesteps, y=None):
cond_mode = self.model.cond_mode
Expand Down
5 changes: 4 additions & 1 deletion model/mdm.py
Expand Up @@ -148,7 +148,10 @@ def forward(self, x, timesteps, y=None):

force_mask = y.get('uncond', False)
if 'text' in self.cond_mode:
enc_text = self.encode_text(y['text'])
if 'text_embed' in y.keys(): # caching option
enc_text = y['text_embed']
else:
enc_text = self.encode_text(y['text'])
emb += self.embed_text(self.mask_cond(enc_text, force_mask=force_mask))
if 'action' in self.cond_mode:
action_emb = self.embed_action(y['action'])
Expand Down

0 comments on commit 94c173f

Please sign in to comment.