diff --git a/downstream_task/report_generation_and_vqa/generation_decode.py b/downstream_task/report_generation_and_vqa/generation_decode.py
index 19358ce..d587b64 100644
--- a/downstream_task/report_generation_and_vqa/generation_decode.py
+++ b/downstream_task/report_generation_and_vqa/generation_decode.py
@@ -16,12 +16,12 @@ import torchvision
 import csv
 
 from transformers import BertTokenizer
-from pytorch_pretrained_bert.modeling_like_cxrbert import BertForSeq2SeqDecoder
+from pytorch_pretrained_bert.model import BertForSeq2SeqDecoder
 
-import sc.seq2seq_loader_itm as seq2seq_loader
-from sc.bleu import language_eval_bleu
-from misc.data_parallel import DataParallelImbalance
-from sc.image_embedding import Img_patch_embedding, fully_sampling, random_sampling
+# import sc.seq2seq_loader_itm as seq2seq_loader
+from data_loader import Preprocess4Seq2seqDecoder
+from bleu import language_eval_bleu
+from data_parallel import DataParallelImbalance
 
 
 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -148,7 +148,7 @@ def main():
 
     bi_uni_pipeline = []
     # def __init__(self, tokenizer, max_len, max_txt_length, new_segment_ids=False, mode="s2s", len_vis_input=None):
-    bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(tokenizer, args.max_seq_length,
+    bi_uni_pipeline.append(Preprocess4Seq2seqDecoder(tokenizer, args.max_seq_length,
         max_txt_length=args.max_txt_length, new_segment_ids=args.new_segment_ids,
         mode='s2s', len_vis_input=args.len_vis_input))
 
@@ -175,14 +175,15 @@ def main():
                 w_list.append(w)
         forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
 
-    model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model,
-        max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
-        state_dict={}, args=args, num_labels=cls_num_labels,
-        type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,
-        search_beam_size=args.beam_size, length_penalty=args.length_penalty,
-        eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
-        forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
-        len_vis_input=args.len_vis_input)
+    # unused
+    # model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model,
+    #     max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
+    #     state_dict={}, args=args, num_labels=cls_num_labels,
+    #     type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,
+    #     search_beam_size=args.beam_size, length_penalty=args.length_penalty,
+    #     eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
+    #     forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
+    #     len_vis_input=args.len_vis_input)
 
     max_a, max_b, max_c, max_d = [], [], [], []
     # for epoch_itr in range(10,int(args.model_recover_path.split('.')[-2])+1):
@@ -203,7 +204,7 @@ def main():
             max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
             state_dict=model_recover, args=args, num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,
-            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
+            length_penalty=args.length_penalty,
             eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
             forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
             len_vis_input=args.len_vis_input)
diff --git a/downstream_task/report_generation_and_vqa/pytorch_pretrained_bert/model.py b/downstream_task/report_generation_and_vqa/pytorch_pretrained_bert/model.py
index 54f58fe..1a171b0 100644
--- a/downstream_task/report_generation_and_vqa/pytorch_pretrained_bert/model.py
+++ b/downstream_task/report_generation_and_vqa/pytorch_pretrained_bert/model.py
@@ -886,12 +886,11 @@ def forward(self, input_id, input_imgs, vis_pe, token_type_ids): # img_embed_ou
         )
 
         if self.args.img_postion:
-            position_ids = torch.arange(seq_len, dtype=torch.long).cuda()
-            position_ids = position_ids.unsqueeze(0).expand(bsz, seq_len)
+            position_ids = torch.tensor([0, seq_len-1], dtype=torch.long).cuda()
+            position_ids = position_ids.unsqueeze(0).expand(bsz, 2)
             position_embeddings = self.position_embeddings(position_ids)
             pos_vis_embeddings = self.position_embeddings(vis_pe)
-            token_position_embeddings = torch.cat((position_embeddings[:,:1], pos_vis_embeddings, position_embeddings[:,self.args.len_vis_input+1:]), dim=1)
-
+            token_position_embeddings = torch.cat((position_embeddings[:,:1], pos_vis_embeddings, position_embeddings[:,-1:]), dim=1)
             embeddings = token_embeddings + token_position_embeddings + token_type_embeddings # should be tensor
         else:
             embeddings = token_embeddings + token_type_embeddings # should be tensor
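Note on the position-embedding change in model.py above: after this diff, only two text positions are looked up from the learned position table (index 0 for the leading token and index seq_len-1 for the trailing token), while every visual patch in between gets its position embedding from the indices carried in vis_pe. Below is a minimal standalone sketch of that assembly; the shapes, the table size, and the dummy vis_pe indices are assumptions for illustration, not the repository's actual values.

    import torch
    import torch.nn as nn

    # Illustrative sketch only: names and sizes here are assumed, not taken from the repo.
    bsz, len_vis_input, hidden = 2, 100, 768
    seq_len = len_vis_input + 2                              # assumed layout: leading token + image patches + trailing token
    position_embeddings = nn.Embedding(512, hidden)          # learned position table
    vis_pe = torch.randint(0, 512, (bsz, len_vis_input))     # dummy per-patch position indices

    # Embed only positions 0 and seq_len-1 for the text side ...
    position_ids = torch.tensor([0, seq_len - 1], dtype=torch.long)
    position_ids = position_ids.unsqueeze(0).expand(bsz, 2)
    text_pos = position_embeddings(position_ids)             # (bsz, 2, hidden)

    # ... and take the visual positions from the indices in vis_pe.
    pos_vis_embeddings = position_embeddings(vis_pe)         # (bsz, len_vis_input, hidden)

    token_position_embeddings = torch.cat(
        (text_pos[:, :1], pos_vis_embeddings, text_pos[:, -1:]), dim=1)
    assert token_position_embeddings.shape == (bsz, seq_len, hidden)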