Merge pull request #2 from hammad26/master
Fetching relevant position embeddings instead of all. May increase s…
SuperSupermoon committed Feb 4, 2022
2 parents b5d0e6b + a0666c1 commit 9c258c7
Showing 2 changed files with 19 additions and 19 deletions.
31 changes: 16 additions & 15 deletions downstream_task/report_generation_and_vqa/generation_decode.py
@@ -16,12 +16,12 @@
 import torchvision
 import csv
 from transformers import BertTokenizer
-from pytorch_pretrained_bert.modeling_like_cxrbert import BertForSeq2SeqDecoder
+from pytorch_pretrained_bert.model import BertForSeq2SeqDecoder

-import sc.seq2seq_loader_itm as seq2seq_loader
-from sc.bleu import language_eval_bleu
-from misc.data_parallel import DataParallelImbalance
-from sc.image_embedding import Img_patch_embedding, fully_sampling, random_sampling
+# import sc.seq2seq_loader_itm as seq2seq_loader
+from data_loader import Preprocess4Seq2seqDecoder
+from bleu import language_eval_bleu
+from data_parallel import DataParallelImbalance


 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -148,7 +148,7 @@ def main():

 bi_uni_pipeline = []
 # def __init__(self, tokenizer, max_len, max_txt_length, new_segment_ids=False, mode="s2s", len_vis_input=None):
-bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(tokenizer, args.max_seq_length,
+bi_uni_pipeline.append(Preprocess4Seq2seqDecoder(tokenizer, args.max_seq_length,
     max_txt_length=args.max_txt_length, new_segment_ids=args.new_segment_ids,
     mode='s2s', len_vis_input=args.len_vis_input))

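For orientation, each entry in bi_uni_pipeline is a callable preprocessor applied to one raw instance at decode time. A minimal sketch of the usual UniLM-style consumption, under the assumption that instances arrive as tuples the preprocessor understands; this loop is not part of the diff:

    # Assumed consumption pattern (UniLM-style decoding), not shown in this diff.
    instances = []
    for instance in batch_chunk:              # hypothetical iterable of raw instances
        for proc in bi_uni_pipeline:          # here a single Preprocess4Seq2seqDecoder
            instances.append(proc(instance))  # -> input ids, segment ids, masks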
@@ -175,14 +175,15 @@ def main():
 w_list.append(w)
 forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))

-model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model,
-    max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
-    state_dict={}, args=args, num_labels=cls_num_labels,
-    type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,
-    search_beam_size=args.beam_size, length_penalty=args.length_penalty,
-    eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
-    forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
-    len_vis_input=args.len_vis_input)
+# unused
+# model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model,
+#     max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
+#     state_dict={}, args=args, num_labels=cls_num_labels,
+#     type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,
+#     search_beam_size=args.beam_size, length_penalty=args.length_penalty,
+#     eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
+#     forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
+#     len_vis_input=args.len_vis_input)

 max_a, max_b, max_c, max_d = [], [], [], []
 # for epoch_itr in range(10,int(args.model_recover_path.split('.')[-2])+1):
@@ -203,7 +204,7 @@ def main():
     max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
     state_dict=model_recover, args=args, num_labels=cls_num_labels,
     type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,
-    search_beam_size=args.beam_size, length_penalty=args.length_penalty,
+    length_penalty=args.length_penalty,
     eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
     forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
     len_vis_input=args.len_vis_input)
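
Read together with the previous hunk, only the construction inside the checkpoint-recovery loop survives, now seeded with the recovered weights and no longer passing search_beam_size. A schematic of the surviving flow, assuming the usual glob-over-checkpoints loop (only the keyword arguments are visible in this diff):

    import glob
    import torch

    # Schematic; the loop and load details are assumptions, kwargs as in the hunk above.
    for model_recover_path in glob.glob(args.model_recover_path):
        model_recover = torch.load(model_recover_path)      # checkpoint state dict
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model, state_dict=model_recover, args=args,
            length_penalty=args.length_penalty,             # search_beam_size dropped
            eos_id=eos_word_ids, min_len=args.min_len,
            len_vis_input=args.len_vis_input)               # ...plus the remaining kwargs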
@@ -886,12 +886,11 @@ def forward(self, input_id, input_imgs, vis_pe, token_type_ids): # img_embed_ou
         )

         if self.args.img_postion:
-            position_ids = torch.arange(seq_len, dtype=torch.long).cuda()
-            position_ids = position_ids.unsqueeze(0).expand(bsz, seq_len)
+            position_ids = torch.tensor([0, seq_len-1], dtype=torch.long).cuda()
+            position_ids = position_ids.unsqueeze(0).expand(bsz, 2)
             position_embeddings = self.position_embeddings(position_ids)
             pos_vis_embeddings = self.position_embeddings(vis_pe)
-            token_position_embeddings = torch.cat((position_embeddings[:,:1], pos_vis_embeddings, position_embeddings[:,self.args.len_vis_input+1:]), dim=1)
-
+            token_position_embeddings = torch.cat((position_embeddings[:,:1], pos_vis_embeddings, position_embeddings[:,-1:]), dim=1)
             embeddings = token_embeddings + token_position_embeddings + token_type_embeddings # should be tensor
         else:
             embeddings = token_embeddings + token_type_embeddings # should be tensor
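
This final hunk is the substance of the commit message: instead of embedding every position in [0, seq_len) and then slicing around the visual span, the new code embeds only the two text positions that survive the concatenation (0 for the leading token, seq_len-1 for the trailing one) and takes the visual positions directly from vis_pe. The shapes imply the input here is one leading token, len_vis_input image patches, and one trailing token. A self-contained sketch of the equivalence, with toy sizes as assumptions:

    import torch
    import torch.nn as nn

    # Toy setup; the sizes are assumptions, not the repo's config.
    bsz, seq_len, hidden, max_pos = 2, 10, 16, 512
    len_vis_input = seq_len - 2                   # patches between first and last token
    position_embeddings = nn.Embedding(max_pos, hidden)
    vis_pe = torch.randint(0, max_pos, (bsz, len_vis_input))  # per-patch position ids

    # Old: embed all seq_len positions, then keep only the first and the last.
    ids_all = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(bsz, seq_len)
    emb_all = position_embeddings(ids_all)
    old = torch.cat((emb_all[:, :1],
                     position_embeddings(vis_pe),
                     emb_all[:, len_vis_input + 1:]), dim=1)

    # New: embed exactly the two positions that are actually used.
    ids_two = torch.tensor([0, seq_len - 1]).unsqueeze(0).expand(bsz, 2)
    emb_two = position_embeddings(ids_two)
    new = torch.cat((emb_two[:, :1],
                     position_embeddings(vis_pe),
                     emb_two[:, -1:]), dim=1)

    assert torch.equal(old, new)  # same tensor, two lookups instead of seq_len

The gather over the embedding table shrinks from seq_len rows to two, which is where the speed-up promised in the commit title would come from; the concatenated result is unchanged.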
