Request for Alpha-CLIP with LLaVA Web Demo and Local Demo #11
Sorry for the delay; I'm currently preparing for my final examinations and don't have time to do this right now. Also, our OpenXLab quota is too limited to run inference with LLaVA-1.5-13B, which makes it difficult to deploy LLaVA there. If you are in a hurry, you can start from the LLaVA code and replace the weights of the CLIP vision tower it uses, as follows:

```python
import collections
import types

import torch

# Alpha mask shared with the patched forward pass; set this before inference.
mask_torch = None

def rewrited_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    print("[Warning] using rewrited alpha forward")
    global mask_torch
    batch_size = pixel_values.shape[0]
    patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
    if mask_torch is None:
        print("[Warning] no mask specified!")
        # Fall back to an all-foreground alpha channel.
        alpha = torch.ones_like(pixel_values[:, [0], :, :]) * 1.9231
    else:
        alpha = mask_torch
    # Add the alpha-channel patch embedding on top of the RGB patch embedding.
    patch_embeds = patch_embeds + self.patch_embedding_alpha(alpha)
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
```
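As a side note, the fallback constant 1.9231 is consistent with an all-ones alpha mask normalized with mean 0.5 and std 0.26 ((1.0 - 0.5) / 0.26 ≈ 1.9231), which matches Alpha-CLIP's mask preprocessing. Assuming that normalization, a minimal sketch for building `mask_torch` from a binary foreground mask might look like this (the `make_alpha` helper name is hypothetical, not part of the original snippet):

```python
import torch
import torch.nn.functional as F

def make_alpha(binary_mask: torch.Tensor, size: int = 336) -> torch.Tensor:
    """binary_mask: [H, W] tensor, 1.0 = foreground, 0.0 = background (hypothetical helper)."""
    alpha = binary_mask.float()[None, None]                     # -> [1, 1, H, W]
    alpha = F.interpolate(alpha, size=(size, size), mode='nearest')
    # Assumed normalization: mean 0.5, std 0.26, matching the 1.9231 fallback above;
    # cast to fp16 on CUDA to match the encoder set up below.
    return ((alpha - 0.5) / 0.26).half().cuda()

# mask_torch = make_alpha(segmentation_mask)  # set before calling the model
```

Continuing the original snippet, the vision tower is then patched and the converted Alpha-CLIP checkpoint loaded: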
```python
# Add the alpha-channel patch embedding to LLaVA's vision tower and swap in
# the rewritten forward pass.
visual_encoder = self.model.model.vision_tower.vision_tower.vision_model
visual_encoder.embeddings.patch_embedding_alpha = torch.nn.Conv2d(
    in_channels=1,  # the alpha channel
    out_channels=visual_encoder.embeddings.patch_embedding.out_channels,
    kernel_size=visual_encoder.embeddings.patch_embedding.kernel_size,
    stride=visual_encoder.embeddings.patch_embedding.stride,
    bias=False)
visual_encoder.embeddings.forward = types.MethodType(rewrited_forward, visual_encoder.embeddings)

# Convert the Alpha-CLIP checkpoint (OpenAI CLIP naming) to HuggingFace
# CLIPVisionModel naming.
state_dict = torch.load('clip_l14@336_grit1m_fultune_8xe.pth')
converted_dict = collections.OrderedDict()
for k, v in state_dict.items():
    if 'transformer.resblocks' in k:
        new_key = k.replace('transformer.resblocks', 'encoder.layers') \
                   .replace('attn', 'self_attn') \
                   .replace('ln_1', 'layer_norm1') \
                   .replace('ln_2', 'layer_norm2') \
                   .replace('c_fc', 'fc1') \
                   .replace('c_proj', 'fc2')
        if ('self_attn' in new_key) and ('out' not in new_key):
            # Split the fused in_proj qkv parameters (hidden size 1024 for ViT-L/14).
            if 'weight' in new_key:
                converted_dict[new_key.replace('in_proj', 'q_proj')] = v[:1024, :]
                converted_dict[new_key.replace('in_proj', 'k_proj')] = v[1024:2048, :]
                converted_dict[new_key.replace('in_proj', 'v_proj')] = v[2048:, :]
            else:
                assert 'bias' in new_key
                converted_dict[new_key.replace('in_proj', 'q_proj')] = v[:1024]
                converted_dict[new_key.replace('in_proj', 'k_proj')] = v[1024:2048]
                converted_dict[new_key.replace('in_proj', 'v_proj')] = v[2048:]
        else:
            converted_dict[new_key] = v
    else:
        new_key = k.replace('class_embedding', 'embeddings.class_embedding') \
                   .replace('conv1.weight', 'embeddings.patch_embedding.weight') \
                   .replace('positional_embedding', 'embeddings.position_embedding.weight') \
                   .replace('conv1_alpha.weight', 'embeddings.patch_embedding_alpha.weight') \
                   .replace('ln_pre.weight', 'pre_layrnorm.weight') \
                   .replace('ln_pre.bias', 'pre_layrnorm.bias') \
                   .replace('ln_post.weight', 'post_layernorm.weight') \
                   .replace('ln_post.bias', 'post_layernorm.bias')
        converted_dict[new_key] = v
visual_encoder.load_state_dict(converted_dict, strict=False)
visual_encoder = visual_encoder.half().cuda()
```
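One caveat worth adding (not in the original snippet): with `strict=False`, any keys the conversion loop mapped incorrectly are silently skipped, so it can be worth inspecting what actually failed to load:

```python
# Sanity-check sketch: load_state_dict returns the keys it could not match,
# making conversion mistakes visible instead of silently ignored.
missing, unexpected = visual_encoder.load_state_dict(converted_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```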
Thanks for your response, and good luck with your final examinations. Appreciate your help!
Hello, do you have plans to release the training code for fine-tuning Alpha-CLIP with LLaVA? The code above covers inference, but I wonder whether it is sufficient for fine-tuning?
I would like to request the addition of a web demo and instructions for a local demo for Alpha-CLIP with LLaVA. Having both options would greatly enhance accessibility and usability for users interested in exploring Alpha-CLIP. Thank you!