Skip to content

Commit

Permalink
add code
Browse files Browse the repository at this point in the history
  • Loading branch information
peizesun committed Apr 10, 2023
1 parent b1314d9 commit fe0f5d7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
20 changes: 13 additions & 7 deletions fastrcnn/modeling/clip_rcnn.py
Expand Up @@ -24,7 +24,14 @@


def build_clip_rcnn(clip_type='RN50'):
clip_rcnn = CLIP_RCNN(clip_type=clip_type)
pooler_res_dict = {
"RN50": 14,
"RN50x4": 18,
"RN50x16": 24,
"RN50x64": 28,
}
pooler_resolution = pooler_res_dict[clip_type]
clip_rcnn = CLIP_RCNN(clip_type=clip_type, pooler_resolution=pooler_resolution)
clip_rcnn.eval()
return clip_rcnn

Expand All @@ -37,12 +44,12 @@ class CLIP_RCNN(nn.Module):
def __init__(
self,
clip_type,
softmax_t: float = 0.01,
pooler_resolution: int = 14,
pooler_resolution,
pooler_scales: int = 16,
sampling_ratio: int = 0,
pooler_type: str = "ROIAlignV2",
canonical_box_size: int = 224,
softmax_t: float = 0.01,
):
super().__init__()
self.register_buffer("pixel_mean_clip",
Expand Down Expand Up @@ -75,11 +82,10 @@ def forward_clip(self, image, boxes, text_prompt):
features = self.clip_res_c4_backbone(imageList_clip.tensor)
text_embed = self.get_text_embeddings(text_prompt)
clip_scores = self.clip_res5_roi_heads(features, boxes, text_embed)
return clip_scores.cpu().tolist()
return clip_scores.cpu()

def get_text_embeddings(self, vocabulary, prefix_prompt='a '):
if not isinstance(vocabulary, list):
vocabulary = [vocabulary]
vocabulary = vocabulary.split(',')
texts = [prefix_prompt + x.lower().replace(':', ' ') for x in vocabulary]
texts_aug = texts + ['background']
emb = self.text_encoder(texts_aug).permute(1, 0)
Expand Down Expand Up @@ -110,7 +116,7 @@ def clip_res5_roi_heads(self, features, boxes, text_embed):
region_features = F.normalize(region_features, p=2, dim=-1)

similarity = ((1 / self.softmax_t) * region_features @ text_embed).softmax(dim=-1)
clip_scores = similarity[:,0]
clip_scores = similarity[:,:-1]
return clip_scores


Expand Down
10 changes: 5 additions & 5 deletions fastrcnn/modeling/text_encoder.py
Expand Up @@ -172,15 +172,15 @@ def build_text_encoder(pretrain=True, visual_type="RN50"):
"visual_type": ["embed_dim", "context_length", "vocab_size",
"transformer_width", "transformer_heads", "transformer_layers"],
"RN50": [1024, 77, 49408, 512, 8, 12],
"ViT-B/32": [512, 77, 49408, 512, 8, 12],
"RN50x4": [640, 77, 49408, 640, 10, 12],
"RN50x16": [768, 77, 49408, 768, 12, 12],
"RN50x64": [1024, 77, 49408, 1024, 16, 12],
}
text_encoder = CLIPTEXT(**{k: v for k, v in zip(clip_dict['visual_type'], clip_dict[visual_type])})
if pretrain:
import clip
if visual_type == 'RN50':
pretrained_model, _ = clip.load("RN50", device='cpu')
elif visual_type == 'ViT-B/32':
pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
if visual_type in clip_dict:
pretrained_model, _ = clip.load(visual_type, device='cpu')
else:
raise NotImplementedError

Expand Down

0 comments on commit fe0f5d7

Please sign in to comment.