
Request for Alpha-CLIP with LLaVA Web Demo and Local Demo #11

Closed
X1AOX1A opened this issue Dec 19, 2023 · 3 comments

X1AOX1A commented Dec 19, 2023

I would like to request the addition of a web demo and instructions for a local demo for Alpha-CLIP with LLaVA. Having both options would greatly enhance accessibility and usability for users interested in exploring Alpha-CLIP. Thank you!

SunzeY (Owner) commented Dec 19, 2023

Sorry for the delay, but I'm currently preparing for my final examinations and don't have time to do this right now. Also, our OpenXLab resources are too limited to run inference with LLaVA-1.5-13B, so it is difficult for us to deploy LLaVA there. If you are in a hurry, you can start from the LLaVA code and replace the weights of the CLIP model it uses in llava/serve/model_worker.py with Alpha-CLIP. We will release the official implementation after my final examinations.

import collections
import types

import torch

# Alpha mask in the vision encoder's normalized alpha space; set this global
# before each forward pass (None falls back to a whole-image foreground mask).
mask_torch = None


def rewrited_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    # Replacement for CLIPVisionEmbeddings.forward that adds the extra
    # alpha-channel patch embedding used by Alpha-CLIP.
    print("[Warning] using rewritten alpha forward")
    global mask_torch
    batch_size = pixel_values.shape[0]
    patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
    if mask_torch is None:
        print("[Warning] no mask specified!")
        # 1.9231 is an all-ones (whole image) alpha after Alpha-CLIP's
        # alpha-channel normalization, i.e. roughly (1 - 0.5) / 0.26.
        alpha = torch.ones_like(pixel_values[:, [0], :, :]) * 1.9231
    else:
        alpha = mask_torch
    patch_embeds = patch_embeds + self.patch_embedding_alpha(alpha)
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


# The part below runs inside LLaVA's model worker (llava/serve/model_worker.py)
# after the model has been loaded, hence the `self` reference.
visual_encoder = self.model.model.vision_tower.vision_tower.vision_model

# Add the alpha-channel convolution alongside the existing patch embedding and
# swap in the rewritten forward.
visual_encoder.embeddings.patch_embedding_alpha = torch.nn.Conv2d(
    in_channels=1,
    out_channels=visual_encoder.embeddings.patch_embedding.out_channels,
    kernel_size=visual_encoder.embeddings.patch_embedding.kernel_size,
    stride=visual_encoder.embeddings.patch_embedding.stride,
    bias=False)
visual_encoder.embeddings.forward = types.MethodType(rewrited_forward, visual_encoder.embeddings)

# Convert the Alpha-CLIP (OpenAI-style) checkpoint keys to HuggingFace names.
state_dict = torch.load('clip_l14@336_grit1m_fultune_8xe.pth')
converted_dict = collections.OrderedDict()
for k, v in state_dict.items():
    if 'transformer.resblocks' in k:
        new_key = k.replace('transformer.resblocks', 'encoder.layers').replace('attn', 'self_attn') \
                   .replace('ln_1', 'layer_norm1').replace('ln_2', 'layer_norm2') \
                   .replace('c_fc', 'fc1').replace('c_proj', 'fc2')
        if ('self_attn' in new_key) and ('out' not in new_key):  # split fused in_proj into q/k/v
            if 'weight' in new_key:
                converted_dict[new_key.replace('in_proj', 'q_proj')] = v[:1024, :]
                converted_dict[new_key.replace('in_proj', 'k_proj')] = v[1024:2048, :]
                converted_dict[new_key.replace('in_proj', 'v_proj')] = v[2048:, :]
            else:
                assert 'bias' in new_key
                converted_dict[new_key.replace('in_proj', 'q_proj')] = v[:1024]
                converted_dict[new_key.replace('in_proj', 'k_proj')] = v[1024:2048]
                converted_dict[new_key.replace('in_proj', 'v_proj')] = v[2048:]
        else:
            converted_dict[new_key] = v
    else:
        # Note: 'pre_layrnorm' (sic) is the actual attribute name in HF's CLIPVisionTransformer.
        new_key = k.replace('class_embedding', 'embeddings.class_embedding') \
                   .replace('conv1.weight', 'embeddings.patch_embedding.weight') \
                   .replace('positional_embedding', 'embeddings.position_embedding.weight') \
                   .replace('conv1_alpha.weight', 'embeddings.patch_embedding_alpha.weight') \
                   .replace('ln_pre.weight', 'pre_layrnorm.weight') \
                   .replace('ln_pre.bias', 'pre_layrnorm.bias') \
                   .replace('ln_post.weight', 'post_layernorm.weight') \
                   .replace('ln_post.bias', 'post_layernorm.bias')
        converted_dict[new_key] = v

visual_encoder.load_state_dict(converted_dict, strict=False)
visual_encoder = visual_encoder.half().cuda()
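
For reference, a minimal sketch of how the mask_torch global above might be populated before inference. The helper name set_alpha_mask is illustrative, and the mean-0.5 / std-0.26 alpha normalization and 336x336 input size are assumptions based on the Alpha-CLIP release; the helper has to live in the same module as rewrited_forward so both refer to the same global.

import numpy as np
from torchvision import transforms

# Illustrative helper (assumed, not part of the repo): turn a binary H x W
# foreground mask into the normalized 1-channel alpha tensor consumed by
# patch_embedding_alpha.
mask_transform = transforms.Compose([
    transforms.ToTensor(),            # H x W float array -> [1, H, W] tensor
    transforms.Resize((336, 336)),
    transforms.Normalize(0.5, 0.26),  # assumed Alpha-CLIP alpha normalization
])

def set_alpha_mask(binary_mask: np.ndarray, device: str = "cuda") -> None:
    global mask_torch
    alpha = mask_transform(binary_mask.astype(np.float32))  # [1, 336, 336]
    mask_torch = alpha.unsqueeze(0).half().to(device)       # [1, 1, 336, 336]

# Example: a whole-image mask, equivalent to the 1.9231 default above.
set_alpha_mask(np.ones((448, 448), dtype=np.uint8))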

X1AOX1A (Author) commented Dec 19, 2023

Thanks for your response, and good luck with your final examinations. Appreciate your help!

lxysl commented Apr 19, 2024

Hello, do you have plans to release the training code for fine-tuning Alpha-CLIP with LLaVA? The code above covers the inference process, but I wonder whether it is enough for fine-tuning.
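
A minimal sketch of what making the patch trainable could look like, building on the visual_encoder from the snippet above. This is illustrative only; whether it is sufficient for fine-tuning is exactly the open question here.

# Illustrative sketch, building on `visual_encoder` from the snippet above:
# freeze the converted CLIP weights and leave just the new alpha convolution
# trainable before handing the model to LLaVA's training loop.
for p in visual_encoder.parameters():
    p.requires_grad_(False)
for p in visual_encoder.embeddings.patch_embedding_alpha.parameters():
    p.requires_grad_(True)

trainable = [n for n, p in visual_encoder.named_parameters() if p.requires_grad]
print("trainable vision-tower params:", trainable)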
