
Is there a way to only use the text encoder? #113

Closed · ranran9991 opened this issue Jun 9, 2021 · 11 comments
@ranran9991

Hey!
I'd like to use only one part of the model, specifically the text encoder, in my work. I don't want to keep the whole model in GPU memory just to use the text encoding part. Is there a simple way to do that, or will I have to dive into the code myself?

Thanks for the help! :)

@vinson2233

What I have done is set model.visual = None to remove the whole visual part. But this raises an error, since the dtype property depends on .visual:

CLIP/clip/model.py, lines 332 to 334 in cfcffb9:

@property
def dtype(self):
    return self.visual.conv1.weight.dtype

I think that if we set everything in the visual part to None except model.visual.conv1.weight, then encode_text will work fine without having to store the rest of the visual tower.
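A minimal sketch of that idea, assuming the published clip package and a ViT checkpoint (whose visual tower has a conv1 layer):

import torch
import clip

device = "cpu"
model, _ = clip.load("ViT-B/32", device=device)

# Keep only conv1 from the visual tower so the dtype property
# (which reads self.visual.conv1.weight.dtype) keeps working,
# and drop everything else on the image side.
visual_stub = torch.nn.Module()
visual_stub.conv1 = model.visual.conv1
model.visual = visual_stub

text = clip.tokenize(["a diagram", "a dog"]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text)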

@ranran9991 (Author)

(in reply to @vinson2233's suggestion above)

Thank you for the help; hopefully this feature will be added soon.

@lonngxiang

Same need here.

@jongwook (Collaborator)

You can replace self.visual.conv1.weight.dtype with next(self.parameters()).dtype and similar, which will avoid the error. I plan to make that change in the next round of updates.
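Until that lands, a small monkey-patch along these lines (a sketch, assuming the published clip package) should avoid the error:

import torch
import clip
from clip.model import CLIP

# Make the dtype property independent of the visual tower,
# then drop the tower entirely; encode_text never touches it.
CLIP.dtype = property(lambda self: next(self.parameters()).dtype)

model, _ = clip.load("ViT-B/32", device="cpu")
model.visual = None  # setting a registered submodule to None is allowed

text = clip.tokenize(["a photo of a cat"])
with torch.no_grad():
    text_features = model.encode_text(text)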

@lonngxiang

(in reply to @jongwook's suggestion above)

Thumbs up!

@lonngxiang

lonngxiang commented Jul 5, 2021

Finally, I succeeded! @jongwook @vinson2233 @ranran9991

model.py

class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()
        # ... the text-side modules (token_embedding, positional_embedding,
        # transformer, ln_final, text_projection, ...) are built exactly as in
        # the original __init__; only the image-related arguments are dropped.
        self.visual = None

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def forward(self, text):
        # image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # the cosine-similarity logits of the original forward() are dropped;
        # only the normalized text features are returned
        return text_features

save.py

import torch
import clip
from model import CLIP  # the trimmed, text-only CLIP class above, assuming it is saved as model.py

device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

dict_trained = model.state_dict()    # weights of the full pretrained model
trained_lst = list(dict_trained.keys())

model_txt = CLIP(embed_dim=512, context_length=77, vocab_size=49408,
                 transformer_width=512, transformer_heads=16, transformer_layers=6)
dict_txt = model_txt.state_dict()
print(dict_txt.keys())  # which weights the text-only model expects

# copy the matching text-encoder weights from the full model
for key in dict_txt:
    dict_txt[key] = dict_trained[key]
model_txt.load_state_dict(dict_txt)
torch.save(model_txt, "./single_model_text1.pkl")
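For completeness, a short usage sketch for the saved file (a sketch, assuming the paths above):

import torch
import clip

# Note: the module that defines the trimmed CLIP class must be importable here,
# since torch.save() pickled the whole model object, not just its weights.
device = "cpu"
model_txt = torch.load("./single_model_text1.pkl", map_location=device)
model_txt.eval()

text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
with torch.no_grad():
    text_features = model_txt(text)  # forward() above already L2-normalizes
print(text_features.shape)  # expected: torch.Size([2, 512])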

@vinson2233

@lonngxiang To format the code, you can use ``` instead of `; put it on the lines above and below the code and everything in between will be rendered as a code block.
But I get the idea. Thanks.

@lonngxiang

(in reply to @vinson2233's formatting tip above)

OK, thanks.

@laurenspriem

laurenspriem commented Feb 28, 2023

(quoting @lonngxiang's model.py and save.py snippets above)

I followed this approach and got a text encoder. However, the embeddings that the model gives are completely wrong. @lonngxiang have you verified that the resulting embeddings match those of the original model?

@jongwook is the above method still the recommended approach, or is there a better way by now?
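One thing worth checking (an observation from the snippets above, not something confirmed in this thread): clip/model.py builds the ViT-B/32 text transformer with transformer_heads=8 and transformer_layers=12, so constructing the trimmed model with 16 heads and 6 layers would still load without errors but give different embeddings. A quick way to compare a standalone encoder against the full model:

import torch
import clip

device = "cpu"
full_model, _ = clip.load("ViT-B/32", device=device)
text_model = torch.load("./single_model_text1.pkl", map_location=device)

text = clip.tokenize(["a photo of a dog"]).to(device)
with torch.no_grad():
    reference = full_model.encode_text(text)
    reference = reference / reference.norm(dim=-1, keepdim=True)
    candidate = text_model(text)  # forward() already normalizes

print(torch.allclose(reference, candidate, atol=1e-4))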

@laurenspriem

laurenspriem commented Feb 28, 2023

Since the above solution didn't work for me, I've used another workaround that was kind of suggested earlier in this thread.

Basically, you can strip away most of the image encoder by setting model.visual.transformer = None. Unlike model.visual = None, this doesn't raise an error.

Downsides are that you're still left with some useless weights from the image encoder, and that you have to use model.encode_text(text_input) instead of just model(text_input).
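A minimal sketch of this workaround, assuming the published clip package:

import torch
import clip

model, _ = clip.load("ViT-B/32", device="cpu")

# Drop the bulk of the image tower; conv1 and a few small layers remain,
# so the dtype property still works and encode_text is unaffected.
model.visual.transformer = None

text = clip.tokenize(["a photo of a cat"])
with torch.no_grad():
    text_features = model.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)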

@ranran9991 (Author)

(in reply to @laurenspriem's workaround above)

This is the workaround that I use as well.
