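### Prepare Model

Load the Chinese RoBERTa-wwm-ext checkpoint (`hfl/chinese-roberta-wwm-ext`) and wrap it in an extractive `Summarizer` that splits sentences with spaCy's Chinese language class.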

In [1]:
# Requires: pip install transformers bert-extractive-summarizer
# Extractive summarizer, its sentence handler, spaCy's Chinese language class, and Transformers loaders
from summarizer import Summarizer
from summarizer.summary_processor import SentenceHandler
from spacy.lang.zh import Chinese
from transformers import AutoModel, AutoTokenizer, AutoConfig

# Fetch config, tokenizer and weights for the Chinese RoBERTa-wwm-ext
# checkpoint through the Transformers Auto* loaders.
modelName = "hfl/chinese-roberta-wwm-ext"

# Request hidden states directly via the config kwargs; presumably the
# summarizer builds its sentence embeddings from them (TODO confirm).
custom_config = AutoConfig.from_pretrained(modelName, output_hidden_states=True)
custom_tokenizer = AutoTokenizer.from_pretrained(modelName)
custom_model = AutoModel.from_pretrained(modelName, config=custom_config)

# The sentence handler receives spaCy's `Chinese` language class so that
# input text is split into Chinese sentences before extraction.
chinese_sentences = SentenceHandler(language=Chinese)
model = Summarizer(
    custom_model=custom_model,
    custom_tokenizer=custom_tokenizer,
    sentence_handler=chinese_sentences,
)

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Prepare Data

Read `data/testA.csv` and collect its `fact` column as a list of strings.

In [2]:
import pandas as pd

# Read the test split; the `fact` column holds the case text to summarize.
data = pd.read_csv('data/testA.csv')
facts = data['fact'].tolist()

# Peek at the first fact.
facts[0]

'罪犯吴智信，男，1990年xx月xx日出生，汉族，小学文化，农民，原户籍所在地广西陆川县平乐镇六凤村18队32号，非累犯，现在广西贵港监狱服刑，以被告人吴智信犯抢劫、盗窃罪，判处有期徒刑二十年，剥夺政治权利四年，并处罚金人民币210000元；该犯不服，提出上诉，经玉林市中级人民法院于2013年xx月xx日作出（2013）玉中少刑—终字第20号刑事判决，判处有期徒刑十五年，剥夺政治权利一年，并处罚金人民币190000元，刑期自2011年xx月xx日起至2026年xx月xx日止。2013年xx月xx日送广西区贵港监狱服刑。服刑期间刑期无变动。执行机关于2015年xx月xx日提出减刑建议书报送本院审理。本院立案后，依法组成合议庭进行审理，现已审理终结。 执行机关认为，罪犯吴智信在服刑期间，能遵守监规纪律，确有悔改表现，建议对该犯减刑一年，并提供了罪犯减刑审核表、计分考核手册，奖励审批表等证据证实。 贵港市人民检察院对罪犯吴智信本次减刑无异议。 经审理查明，罪犯吴智信自在服刑期间，能认罪悔罪；认真遵守法律法规及监规，接受教育改造；积极参加思想、文化、职业技术教育；积极参加劳动，努力完成劳动任务。获得2014年度监狱劳动能手（已折算成奖励分），截至2015年5月止，累计奖励分85.33分。上述事实有生效裁判文书、执行通知书、罪犯减刑审核表、改造鉴定表、奖惩审批表、计分考核手册等证据证实。 另查明，该犯至今未缴纳原判罚金190000元。 本院认为，罪犯吴智信在刑罚执行期间，能认罪服法，具备了法定的确有悔改表现的事实，符合减刑的法定条件。执行机关所提请减刑的建议，符合法律规定，本院根据该犯未履行原判财产刑的事实及其改造表现事实酌情予以减刑。'

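### Summarize

Shorten every fact to about 510 characters (presumably chosen to fit BERT's 512-token input limit).

- Facts that are already short enough are kept as-is.
- Otherwise, extractive summaries with 10 down to 3 sentences are tried, and the first one within 10% of the limit is used.
- If none fits, the tail of the fact is used instead.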

In [3]:
def summarize(text, output_length=510, summarizer=None,
              max_sentences=10, min_sentences=3, tolerance=1.1):
    """Shorten ``text`` to roughly ``output_length`` characters.

    Texts that already fit are returned unchanged. Otherwise extractive
    summaries are requested with ``max_sentences`` down to ``min_sentences``
    sentences (inclusive), and the first one whose length is at most
    ``tolerance * output_length`` characters is returned. If no candidate is
    short enough, the last ``output_length`` characters of ``text`` are
    returned instead.

    Parameters
    ----------
    text : str
        Document to shorten.
    output_length : int, default 510
        Target length in characters.
    summarizer : callable, optional
        Called as ``summarizer(text, num_sentences=k)``; must return a str.
        Defaults to the notebook-level ``model``.
    max_sentences, min_sentences : int, default 10 and 3
        Inclusive range of sentence counts to try, longest first.
    tolerance : float, default 1.1
        Factor on ``output_length`` that an extractive summary may reach and
        still be accepted.

    Returns
    -------
    str
        ``text`` itself, the first accepted summary, or the tail of ``text``.
    """
    # `<=`: a text of exactly `output_length` chars already fits, consistent
    # with the batch loop's fallback, which only truncates longer texts.
    if len(text) <= output_length:
        return text
    if summarizer is None:
        summarizer = model

    limit = tolerance * output_length
    for num_sentences in range(max_sentences, min_sentences - 1, -1):
        result = summarizer(text, num_sentences=num_sentences)
        if len(result) <= limit:
            return result

    # Every candidate exceeded `limit` (hence also `output_length`), so no
    # summary is usable; fall back to the end of the document. The slice start
    # is computed explicitly (len(text) > output_length here) so the final
    # character is kept and output_length == 0 yields "" rather than `text`.
    return text[len(text) - output_length:]

In [4]:
from tqdm.autonotebook import tqdm
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning
# Silence sklearn ConvergenceWarning, presumably emitted by clustering inside
# the summarizer on facts with few sentences (TODO confirm source).
simplefilter("ignore", category=ConvergenceWarning)

MAX_CHARS = 510  # keep in sync with summarize()'s default output_length

summ_facts = []
error_idx = []  # positions in `facts` whose summarization raised

for i, fact in enumerate(tqdm(facts)):
    try:
        summ_facts.append(summarize(fact))
    except Exception:
        # Catch Exception rather than using a bare `except:`, so that
        # KeyboardInterrupt/SystemExit can still stop this long loop.
        error_idx.append(i)
        # Fall back to the last MAX_CHARS characters (the whole fact when it is
        # shorter), keeping the final character instead of dropping it.
        summ_facts.append(fact[-MAX_CHARS:])

  0%|          | 0/25001 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


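### Save to CSV

Attach the summaries as a `summary` column, check their length distribution, and write the result to `data/testA_bertsum.csv`.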

In [5]:
# One summary per row, in the same order as `facts`; then inspect the
# distribution of summary lengths (in characters).
data['summary'] = summ_facts
data['summary'].str.len().describe()

count    25001.000000
mean       450.615735
std         68.388158
min        166.000000
25%        402.000000
50%        459.000000
75%        505.000000
max        561.000000
Name: summary, dtype: float64

In [6]:
# Persist the original columns plus `summary`, omitting the pandas index.
data.to_csv(path_or_buf='data/testA_bertsum.csv', index=False)