In [74]:
import argparse
import torch
from modelhelper import (
    init_logger,
    ModelHelper
)
import numpy as np
import os
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
class ClassificationHelper():
    def __init__(self, args):
        args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

        model_helper = ModelHelper()
        dict_modelset = model_helper.get_modelset(args.model_type)

        tokenizer = dict_modelset['tokenizer'].from_pretrained(
            args.model_name_or_path,
            do_lower_case=args.do_lower_case
        )        
        config = dict_modelset['config'].from_pretrained(args.model_name_or_path)   
        # self.model = dict_modelset['model'].from_pretrained(args.model_name_or_path)
        
        self.model = dict_modelset['model'](config)
        # print(self.model)
        self.model.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "pytorch_model.bin")))
                
        self.model.to(args.device)
        self.model_id2label = config.id2label
        self.model_label2id = {_label:_id for _id, _label in config.id2label.items()}
        self.args = args
        self.tokenizer = tokenizer


    def classifyList_decoding(self, list_text, top_n=3, decode_seq_len=5):
        list_result = []
        
        batch_encoding = self.tokenizer.batch_encode_plus(
            [(text_data, None) for text_data in list_text],
            max_length=self.args.max_seq_len,
            padding="max_length",
            add_special_tokens=True,
            truncation=True,
        )
        
        all_input_ids = torch.tensor([e for e in batch_encoding['input_ids']], dtype=torch.long)
        all_attention_mask = torch.tensor([e for e in batch_encoding['attention_mask']], dtype=torch.long)     
        all_token_type_ids = torch.tensor([e for e in batch_encoding['token_type_ids']], dtype=torch.long)
        all_decoder_inputs = torch.tensor([self.model_label2id.get('#') for _ in range(len(batch_encoding['input_ids']))], dtype=torch.long)

        tensor_dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_decoder_inputs)
        sequence_sampler = SequentialSampler(tensor_dataset)
        dataloader = DataLoader(tensor_dataset, sampler=sequence_sampler, batch_size=self.args.eval_batch_size)

        for batch in dataloader:
            self.model.eval()
            batch = tuple(t.to(self.args.device) for t in batch)
            
            list_logits, list_preds = [], []
            with torch.no_grad():
                inputs_1 = {
                    "encoder_input_ids": batch[0],                
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                }
                encoder_out = self.model.encode(**inputs_1)
                
                batch_size = batch[0].shape[0]
                decoder_hidden_states = None
                temp_input_ids = batch[3].view(-1, 1)
                feed_tensor = torch.unsqueeze(encoder_out[:, 0, :], 1)
                
                for i in range(self.args.decode_seq_len):
                    inputs_2 = {
                        "encoder_out": encoder_out,
                        "decoder_input_ids": temp_input_ids,
                        "decoder_feeds":feed_tensor, # [batch_size, 1, hidden_size],
                        "decoder_hidden_states": decoder_hidden_states
                    }
                    logits, decoder_hidden_states, feed_tensor = self.model.decode(**inputs_2)

                    temp_logits = torch.nn.functional.softmax(logits, dim=2) # [batch_size, decoder_seqlen, num_labels]
                    temp_preds = torch.argmax(logits, dim=2) # [batch_size, decoder_seqlen]

                    list_logits.append(temp_logits[:,-1, :])
                    list_preds.append(torch.unsqueeze(temp_preds[:,-1], 1))

                    temp_input_ids = temp_preds[:,-1].view(-1, 1)
                    
            total_logits = torch.cat(list_logits, dim=1)
            total_pred = torch.cat(list_preds, dim=1)
            
            np_logits = total_logits.detach().cpu().numpy().reshape(-1, decode_seq_len, len(self.model_id2label))
            np_preds = np.argsort(np_logits, axis=2)
            
            topn_pred = np_preds[:,:,-top_n:]
            topn_prob = np.take_along_axis(np_logits, topn_pred, axis=2)

            for list_pred, list_prob in zip(topn_pred, topn_prob):
                dict_result = {}
                temp_pred = ''
                for decode_idx in range(decode_seq_len):
                    list_pred_ele = list_pred[decode_idx]
                    list_prob_ele = list_prob[decode_idx]     
                    print(f"list_pred_ele : {list_pred_ele}")
                    print(f"list_prob_ele : {list_prob_ele}")
                    for cnt, (_pred, _prob) in enumerate(zip(list_pred_ele, list_prob_ele)):
                        #print(f"pred1:  {_pred}")
                        _pred = self.model_id2label.get(int(_pred))
                        #print(f"pred2:  {_pred}")
                        topn_idx = top_n-cnt
                        print(f'{decode_idx+1}_{topn_idx}_pred : {_pred}')
                        dict_result.update({f'{decode_idx+1}_{topn_idx}_pred':temp_pred + str(_pred), f'{decode_idx+1}_{topn_idx}_prob':str(_prob)})
                    
                    temp_pred = temp_pred + str(list_pred_ele[0])
                    #temp_pred = temp_pred + str(_pred)
                list_result.append(dict_result)

        return list_result  # [{'1_pred':'', '1_prob':'', '2_pred':'', '2_prob':''}, ...]

In [75]:
import argparse
import logging
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from fastprogress.fastprogress import progress_bar
from attrdict import AttrDict

from modelhelper import (
    init_logger,
    ModelHelper
)

In [76]:
parser = argparse.ArgumentParser()

parser.add_argument("--input_file", type=str, default="")
parser.add_argument("--output_file", type=str, default="")
parser.add_argument("--top_n", type=int, default=3)
parser.add_argument("--decode_seq_len", type=int, default=5)

parser.add_argument("--write_mode", type=int, default=2)

parser.add_argument("--buff_size", type=int, default=1024)

# parser.add_argument("--task", type=str, default="")
parser.add_argument("--model_type", type=str, default="bert_generation")
parser.add_argument("--model_name_or_path", type=str, default="/docker/model_results/v_c/patent_bert_large_generation_v_c/checkpoint-271852")

parser.add_argument("--max_seq_len", type=int, default=512)
parser.add_argument("--eval_batch_size", type=int, default=32)

parser.add_argument("--do_lower_case", action='store_true', help="")
parser.add_argument("--no_cuda", action='store_true', help="")

parser.add_argument("--label_embedding", action='store_true', help="")
parser.add_argument("--multiclass", action='store_true', help="")    
parser.add_argument("--do_infer", action='store_true', help="")
args = parser.parse_args(args=[])

In [77]:
classification_helper = ClassificationHelper(args)   

In [78]:
list_temp_text = ["A pulse generator (1) that generates a signal by the rotation of the head drum in a video recorder that includes a rotating head drum partially wrapped by a magnetic tape. A frequency converter 6 is used to change the frequency of the signal generated by the pulse generator within a predetermined value range. In order to correct the head exchange time point, a generator (3) for generating a pulse corresponding to the frequency converted signal is generated. A memory 4 for storing signals generated by the generator is generated by the generator. The record features that the signal generated by the pulse generator 1 is changed by the generator 3 in the process of adjusting the pulse generator 1 to the generator 5, which generates a head-switching signa"]

In [79]:
list_key = []

for i in range(1, args.decode_seq_len+1, 1):
    for j in range(1, args.top_n+1, 1):
        list_key.append(f'{i}_{j}_pred')
        list_key.append(f'{i}_{j}_prob')

In [80]:
list_temp_result = classification_helper.classifyList_decoding(list_temp_text, args.top_n)

list_pred_ele : [7 3 6]
list_prob_ele : [6.8900976e-05 1.0151900e-03 9.9887222e-01]
1_3_pred
1_2_pred
1_1_pred
list_pred_ele : [ 2  9 10]
list_prob_ele : [1.6934413e-04 3.9739185e-03 9.9569017e-01]
2_3_pred
2_2_pred
2_1_pred
list_pred_ele : [10  3  2]
list_prob_ele : [1.8309740e-06 6.1403756e-05 9.9993551e-01]
3_3_pred
3_2_pred
3_1_pred
list_pred_ele : [5 3 2]
list_prob_ele : [2.3343149e-05 2.1766005e-02 9.7819978e-01]
4_3_pred
4_2_pred
4_1_pred
list_pred_ele : [10  3  2]
list_prob_ele : [8.58060594e-06 1.16560095e-05 9.99979138e-01]
5_3_pred
5_2_pred
5_1_pred


In [63]:
i=0
result_BUFF = '\t'.join([list_temp_result[i][key] for key in list_key])
result_BUFF2 = [(key, list_temp_result[i][key]) for key in list_key]

In [66]:
result_BUFF2

[('1_1_pred', '5'),
 ('1_1_prob', '0.9988722'),
 ('1_2_pred', '2'),
 ('1_2_prob', '0.00101519'),
 ('1_3_pred', '6'),
 ('1_3_prob', '6.8900976e-05'),
 ('2_1_pred', '59'),
 ('2_1_prob', '0.99569017'),
 ('2_2_pred', '58'),
 ('2_2_prob', '0.0039739185'),
 ('2_3_pred', '51'),
 ('2_3_prob', '0.00016934413'),
 ('3_1_pred', '591'),
 ('3_1_prob', '0.9999355'),
 ('3_2_pred', '592'),
 ('3_2_prob', '6.1403756e-05'),
 ('3_3_pred', '599'),
 ('3_3_prob', '1.830974e-06'),
 ('4_1_pred', '5911'),
 ('4_1_prob', '0.9781998'),
 ('4_2_pred', '5912'),
 ('4_2_prob', '0.021766005'),
 ('4_3_pred', '5914'),
 ('4_3_prob', '2.334315e-05'),
 ('5_1_pred', '59111'),
 ('5_1_prob', '0.99997914'),
 ('5_2_pred', '59112'),
 ('5_2_prob', '1.16560095e-05'),
 ('5_3_pred', '59119'),
 ('5_3_prob', '8.580606e-06')]

In [54]:
result_BUFF2[:10]

[('1_1_pred', '5'),
 ('1_1_prob', '0.9988722'),
 ('1_2_pred', '2'),
 ('1_2_prob', '0.00101519'),
 ('1_3_pred', '6'),
 ('1_3_prob', '6.8900976e-05'),
 ('2_1_pred', '79'),
 ('2_1_prob', '0.99569017'),
 ('2_2_pred', '78'),
 ('2_2_prob', '0.0039739185')]

In [9]:
res = []
for i in range(1,4):
    pred = ''
    score = ''
    for j in range(1,6):
        pred += list_temp_result[0][f"{j}_{i}_pred"][-1]
    res.append(pred)

In [13]:
sp_line[:11]

['5',
 '0.9988722',
 '2',
 '0.00101519',
 '6',
 '6.8900976e-05',
 '79',
 '0.99569017',
 '78',
 '0.0039739185',
 '71']

In [49]:
sp_line =result_BUFF.split('\t')
temp_pred = ''
for e in [2, 4, 6, 8, 10]:
    temp_pred += sp_line[e][-1]
    
temp_prob = 1        
for e in [3, 5, 7, 9, 11]:
    temp_prob = temp_prob * float(sp_line[e])

In [51]:
temp_pred

'26981'