In [1]:
from utils import BlockLenAdjuster
import torch
from pytorch_transformers import BertTokenizer
from models.model_builder import AbsSummarizer
from models.predictor import build_predictor
from others.logging import logger, init_logger
from models.data_loader import load_text
import os
import spacy

In [15]:
class SmartNews:
    """Abstractive summarizer built on a BERT checkpoint (``AbsSummarizer``).

    Construction loads the checkpoint, the BERT tokenizer, the predictor and a
    spaCy pipeline once; ``summarize`` can then be called repeatedly on files
    located under ``config.input_path``.
    """

    # Architecture options copied from the checkpoint's saved options onto the
    # config so the rebuilt model matches the trained weights.
    MODEL_FLAGS = ('hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers',
                   'enc_hidden_size', 'enc_ff_size', 'dec_layers', 'dec_hidden_size',
                   'dec_ff_size', 'encoder', 'ff_actv', 'use_interval')

    def __init__(self, config):
        """Load checkpoint, model, tokenizer, predictor and spaCy pipeline.

        Args:
            config: settings object. ``test_from`` names the checkpoint file,
                ``visible_gpus == '-1'`` selects CPU, ``temp_dir`` is the
                tokenizer cache and ``input_path`` is the adjuster's input
                directory. Attributes listed in ``MODEL_FLAGS`` are
                overwritten IN PLACE with the values stored in the checkpoint.
        """
        self.config = config
        self.adjuster = BlockLenAdjuster(input_dir=config.input_path)
        logger.info('Loading checkpoint from %s' % config.test_from)
        self.device = "cpu" if config.visible_gpus == '-1' else "cuda"

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        # The map_location callable keeps every tensor on its CPU storage.
        checkpoint = torch.load(config.test_from, map_location=lambda storage, loc: storage)
        # presumably 'opt' is the argparse Namespace saved at training time
        opt = vars(checkpoint['opt'])

        overridden = {k: v for k, v in opt.items() if k in self.MODEL_FLAGS}
        for k, v in overridden.items():
            setattr(config, k, v)
        # Report what was actually taken from the checkpoint; printing the
        # config object itself only showed its default object repr.
        logger.info('Model flags from checkpoint: %s' % overridden)

        model = AbsSummarizer(self.config, self.device, checkpoint)
        model.eval()  # inference only

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=config.temp_dir)
        # Unused BERT vocabulary slots serve as the special BOS/EOS/EOQ markers.
        symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
                   'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
        self.predictor = build_predictor(config, tokenizer, symbols, model, logger)
        self.nlp = spacy.load('en_core_web_sm')

    def summarize(self, file_name):
        """Summarize ``file_name`` (a name inside ``config.input_path``).

        The adjuster compacts the file first; the code then reads
        ``proc_<file_name>`` from ``config.input_path`` and expects the summary
        at ``config.output_path/result_<file_name>``.

        Returns:
            The token-count dict produced by ``diagnostic``.
        """
        self.adjuster.compact_paragraph(file_name)
        # The predictor was built with this same config object; presumably it
        # reads result_path from it when writing (confirm in models.predictor).
        self.config.text_src = os.path.join(self.config.input_path, f'proc_{file_name}')
        self.config.result_path = os.path.join(self.config.output_path, f'result_{file_name}')
        test_iter = load_text(self.config, self.config.text_src, '', self.device)
        # -1 is passed as the step argument (presumably "no training step").
        self.predictor.translate(test_iter, -1)
        return self.diagnostic(self.config)

    def diagnostic(self, config):
        """Compare spaCy token counts of the summarizer input and output.

        Args:
            config: object with ``text_src`` (input file path) and
                ``result_path`` (generated summary file path).

        Returns:
            dict with ``input_token``, ``output_token`` and
            ``compress_ratio = output_token / input_token``. The ratio is 0.0
            when the input has no tokens instead of raising ZeroDivisionError.
        """
        doc_input = self._read_doc(config.text_src)
        doc_output = self._read_doc(config.result_path)

        input_token = len(doc_input)
        output_token = len(doc_output)
        compress_ratio = output_token / input_token if input_token else 0.0
        return {
            'input_token': input_token,
            'output_token': output_token,
            'compress_ratio': compress_ratio
        }

    def _read_doc(self, path):
        """Read ``path`` as UTF-8 text and run it through the spaCy pipeline."""
        # Explicit encoding: the platform default depends on the locale.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Parse after closing so the file is not held open during NLP.
        return self.nlp(text)
        
        

In [16]:
class Config:
    """Attribute bag of settings passed to SmartNews and the model/predictor helpers.

    Every option gets a default below. Keyword arguments override defaults,
    e.g. ``Config(test_from='../models/model_step_148000.pt')``; an unknown
    keyword raises ``TypeError`` so typos are not silently turned into new,
    unused attributes. ``repr`` lists every option, so ``print(config)`` is
    informative.
    """

    def __init__(self, **overrides):
        # Task, mode and file locations.
        self.task = 'abs'
        self.encoder = 'bert'
        self.mode = 'test_text'
        self.bert_data_path = '../bert_data_new/cnndm'
        self.model_path = '../models/'
        self.result_path = ''
        self.temp_dir = '../temp'
        self.text_src = ''
        self.text_tgt = ''

        # Batching.
        self.batch_size = 140
        self.test_batch_size = 200
        self.max_ndocs_in_batch = 6

        self.max_pos = 512
        self.use_interval = True
        self.large = False
        self.load_from_extractive = ''

        self.sep_optim = False
        self.lr_bert = 2e-3
        self.lr_dec = 2e-3
        self.use_bert_emb = False

        # Encoder/decoder sizes. SmartNews overwrites several of these
        # (e.g. enc_layers, dec_hidden_size) with the checkpoint's values.
        self.share_emb = False
        self.finetune_bert = True
        self.dec_dropout = 0.2
        self.dec_layers = 6
        self.dec_hidden_size = 768
        self.dec_heads = 8
        self.dec_ff_size = 2048
        self.enc_hidden_size = 512
        self.enc_ff_size = 512
        self.enc_dropout = 0.2
        self.enc_layers = 6

        self.ext_dropout = 0.2
        self.ext_layers = 2
        self.ext_hidden_size = 768
        self.ext_heads = 8
        self.ext_ff_size = 2048

        # Generation / decoding.
        self.label_smoothing = 0.1
        self.generator_shard_size = 32
        self.alpha = 0.6
        self.beam_size = 5
        self.min_length = 15
        self.max_length = 150
        self.max_tgt_len = 140

        # Optimization (training-time options; kept for interface parity).
        self.param_init = 0
        self.param_init_glorot = True
        self.optim = 'adam'
        self.lr = 1
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.warmup_steps = 8000
        self.warmup_steps_bert = 8000
        self.warmup_steps_dec = 8000
        self.max_grad_norm = 0

        self.save_checkpoint_steps = 5
        self.accum_count = 1
        self.report_every = 1
        self.train_steps = 1000
        self.recall_eval = False

        # Devices, logging, seed. SmartNews treats visible_gpus == '-1' as CPU.
        self.visible_gpus = '0'
        self.gpu_ranks = '0'
        self.log_file = './logs/pred.log'
        self.seed = 666

        # Checkpoint selection; test_from must be set before building SmartNews.
        self.test_all = False
        self.test_from = ''
        self.test_start_from = -1

        self.train_from = ''
        self.report_rouge = True
        self.block_trigram = True

        # Raw input directory and prediction output directory.
        self.input_path = '../input_raw_text/'
        self.output_path = '../pred/'

        for name, value in overrides.items():
            if name not in vars(self):
                raise TypeError(f'Config got an unexpected option {name!r}')
            setattr(self, name, value)

    def __repr__(self):
        """List every option as ``Config(name=value, ...)``."""
        options = ', '.join(f'{k}={v!r}' for k, v in vars(self).items())
        return f'{type(self).__name__}({options})'
            

In [17]:
# Checkpoint SmartNews loads (read from the config's ``test_from``).
CHECKPOINT_PATH = '../models/model_step_148000.pt'
c = Config()
c.test_from = CHECKPOINT_PATH

In [18]:
# Build the summarizer once (loads checkpoint, tokenizer, spaCy); reused below.
model = SmartNews(config=c)

<__main__.Config object at 0x7f28cd487438>


In [19]:
# Summarize one raw-text file; the result holds token-count diagnostics.
r = model.summarize(file_name='amzn.txt')

8


In [20]:
# Show the diagnostics dict returned by SmartNews.summarize.
r

{'input_token': 464,
 'output_token': 185,
 'compress_ratio': 0.39870689655172414}

In [21]:
# NOTE(review): leftover debugging cell. Config.__init__ never sets `parser`,
# so this relies on something else attaching it to `c` -- confirm, or remove
# this cell together with its long output.
dir(c.parser)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_action_groups',
 '_actions',
 '_add_action',
 '_add_container_actions',
 '_check_conflict',
 '_check_value',
 '_defaults',
 '_get_args',
 '_get_formatter',
 '_get_handler',
 '_get_kwargs',
 '_get_nargs_pattern',
 '_get_option_tuples',
 '_get_optional_actions',
 '_get_optional_kwargs',
 '_get_positional_actions',
 '_get_positional_kwargs',
 '_get_value',
 '_get_values',
 '_handle_conflict_error',
 '_handle_conflict_resolve',
 '_has_negative_number_optionals',
 '_match_argument',
 '_match_arguments_partial',
 '_mutually_exclusive_groups',
 '_negative_number_matcher',
 '_option_string_actions',
 '_optionals',
 '_parse_know