In [9]:
import torch
from transformers import AutoTokenizer, BertForTokenClassification
import math

# Hub id of the checkpoint; the tokenizer and the model weights are both
# loaded from it.
model_path = "tim1900/bert-chunker-3"
device = "cpu"  # or 'cuda'

# model_max_length=255 is the same value chunk_text uses as its window size.
# NOTE(review): trust_remote_code=True allows code shipped in the Hub repo to
# run while loading -- confirm this checkpoint actually needs it.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="right",
    model_max_length=255,
    trust_remote_code=True,
)

# Load the token-classification weights, then place them on the chosen device.
model = BertForTokenClassification.from_pretrained(model_path)
model = model.to(device)

def chunk_text(model, text, tokenizer, prob_threshold=0.5):
    """Split ``text`` into chunks at tokens the model marks as chunk starts.

    The text is tokenized once without truncation and then scanned with a
    sliding window of at most ``MAX_TOKENS`` tokens (the two special tokens
    included). Within a window, a token opens a new chunk when the two-class
    softmax probability of class 1 exceeds ``prob_threshold``. When a window
    yields splits, the next window starts at the last split token; otherwise
    the window advances by its full width.

    Args:
        model: Token-classification model. It must expose ``device`` and
            ``eval()``, be callable as
            ``model(input_ids=..., attention_mask=...)`` and return an output
            whose ``["logits"]`` has two classes per token. It is switched to
            eval mode as a side effect.
        text: The string to split.
        tokenizer: Callable returning an encoding with ``["input_ids"]`` and
            a ``token_to_chars`` method. The first and last token ids of the
            encoding are treated as the special tokens wrapped around every
            window.
        prob_threshold: Probability in ``[0, 1]`` a token must exceed to
            start a new chunk. Lower values produce more chunks; ``0`` splits
            before every token after the first one, ``1`` never splits.

    Returns:
        A ``(substrings, token_pos)`` pair. ``substrings`` are consecutive
        slices of ``text`` that concatenate back to ``text``. ``token_pos``
        holds, for each chunk, the index of its first token in the full
        tokenized sequence (leading special token counted); the first entry
        is always 0.

    Raises:
        ValueError: If ``prob_threshold`` lies outside ``[0, 1]``.
    """
    MAX_TOKENS = 255  # window length, counting the two special tokens
    window_size = MAX_TOKENS - 2

    if not 0 <= prob_threshold <= 1:
        raise ValueError(
            f"prob_threshold must be within [0, 1], got {prob_threshold!r}"
        )
    # softmax P(class 1) = sigmoid(l1 - l0) > p  <=>  l1 > l0 - log(1/p - 1).
    # The endpoints are the limits of that margin; computing them directly
    # would divide by zero (p == 0) or take log(0) (p == 1).
    if prob_threshold == 0:
        logits_threshold = math.inf
    elif prob_threshold == 1:
        logits_threshold = -math.inf
    else:
        logits_threshold = math.log(1 / prob_threshold - 1)

    encoding = tokenizer(text, return_tensors="pt", truncation=False)
    device = model.device
    all_ids = encoding["input_ids"].to(device)
    # Keep the special tokens aside so they can be re-attached to each window.
    cls_token = all_ids[:, :1]
    sep_token = all_ids[:, -1:]
    body_ids = all_ids[:, 1:-1]
    num_tokens = body_ids.shape[1]

    model.eval()
    split_str_poses = []
    token_pos = []
    windows_start = 0
    windows_end = 0
    print(f"Processing {num_tokens} tokens...")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Stop once a window has reached past the last token. The second
        # condition skips a window that would contain no text tokens at all
        # (empty text, or the previous window ended exactly on the last token
        # without a split); such a window could never produce a split.
        while windows_end <= num_tokens and windows_start < num_tokens:
            windows_end = windows_start + window_size
            ids = torch.cat(
                (cls_token, body_ids[:, windows_start:windows_end], sep_token), 1
            )
            output = model(
                input_ids=ids,
                attention_mask=torch.ones(1, ids.shape[1], device=device),
            )
            # Drop the predictions for the two special tokens.
            logits = output["logits"][:, 1:-1, :]
            is_split = logits[:, :, 1] > (logits[:, :, 0] - logits_threshold)
            # Offset 0 is the window's first token, which already starts a
            # chunk (text start or the previous window's last split).
            new_offsets = [
                sp for sp in torch.where(is_split)[1].tolist() if sp > 0
            ]

            if not new_offsets:
                # Nothing to split here: move on by a whole window.
                windows_start = windows_end
                continue

            # +1 converts to an index in the full sequence (leading special
            # token included), which is what token_to_chars expects.
            positions = [sp + windows_start + 1 for sp in new_offsets]
            split_str_poses += [
                encoding.token_to_chars(pos).start for pos in positions
            ]
            token_pos += positions
            # Restart at the last split so the next chunk is judged with its
            # own beginning as context.
            windows_start += new_offsets[-1]

    # Cut the text at the collected character offsets.
    boundaries = [0] + split_str_poses + [len(text)]
    substrings = [text[i:j] for i, j in zip(boundaries, boundaries[1:])]
    return substrings, [0] + token_pos


# Demo input. Despite the banner text printed below, the sample is not
# source-code documentation: it appears to be text extracted from the KPI
# tables of a sustainability report.
print("\n>>>>>>>>> Chunking code docs...")
# Raw triple-quoted string: the literal starts with a newline and contains one
# embedded line break, both of which are part of the text passed to chunk_text.
doc = r"""
APPENDIX GOVERNANCE SOCIAL RESPONSIBILITY ENVIRONMENTAL RESPONSIBILITY Al A.2 Key Performance Indicators KPIs A.2.1 OPERATIONS DATA an ee 1,598,606 444,057 1,503,763 417,714 1,557,968 432,769 116,004 68,551 773 EE Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Intensity Total energy consumption, gigajoule GJ megawatt hour MWh Total fuel consumption from nonrenewable sources, gigajoule GJ megawatt hour MWh 404,021 112,228 360,169 100,048 417,613 Natural gas, gigajoule GJ megawatt hour MWh 235,616 65,449 203,999 56,667 246,783 Gasoline, gigajoule GJ megawatt hour MWh 2,376 660 1,860 517 2,784 Diesel, gigajoule GJ megawatt hour MWh 110,585 30,718 113,134 31,426 118,590 32,942 e Propane, gigajoule GJ megawatt hour MWh 7,499 2,069 6,520 1,811 7,532 2,092 e LPG, gigajoule GJ megawatt hour MWh 27,463 7,628 25,309 7,030 28,431 7,898 e LNG, gigajoule GJ megawatt hour MWh 101 28 0 0 0 0 e Jet fuel, gigajoule GJ megawatt hour MWh 20,227 5,619 9,082 2,523 13,458 3,738 oak eusdioule all mecewathourikiWh methyl acetylene-propadiene 205 57 267 7A 35 10 Total fuel consumption from renewable sources, gigajoule GJ megawatt hour MWh 0 0 0 0 0 0 Indirect energy usage, gigajoule GJ megawatt hour MWh 1,194,585 331,829 1,143,594 317,665 1,140,355 316,766 e Electricity consumption, gigajoule GJ megawatt hour MWh 1,192,249 331,181 1,142,365 317,324 1,138,859 316,350 e Heating consumption, gigajoule GJ megawatt hour MWh 2,336 649 1,229 341 1,496 416 Cooling consumption, gigajoule GJ megawatt hour MWh 0 0 0 0 0 0 e Steam consumption, gigajoule GJ megawatt hour MWh 0 0 0 0 0 0 Electricity from renewable sources, gigajoule GJ megawatt hour MWh 79,457 22,071 87,075 24,188 101,319 28,144 Electricity from nonrenewable sources, gigajoule GJ megawatt hour MWh 1,112,792 309,109 1,055,290 293,136 1,037,540 288,206 Renewable electricity share of total electricity, percent 6.66% 7.62% 8.90% Emissions avoided due to 
purchased renewable electricity, metric tons of CO, e 7,982 7,917 9,102 Total energy use normalized per $ million annual turnover, gigajoule $1M GJ $1M 191.56 53.21 178.26 49.52 181.43 50.40 Category M4 2019 2020 clip GHG Emissions Total GHG emissions Scope 1, metric tons of CO, e 71,740 63,037 54,108 GHG Emissions Total GHG emissions Scope 2 location-based , metric tons of CO, e 127,229 117,376 117,329 GHG Emissions Total GHG emissions Scope 2 market-based , metric tons of CO, e 119,247 109,459 108,221 GHG Emissions Total Scope 1 and 2 GHG emissions location-based , metric tons of CO, e 198,969 180,413 171,431 GHG Emissions Total Scope 1 and 2 GHG emissions market based , metric tons of CO, e 190,987 172,496 162,329 37 Investing in Our Future 2022 Sustainability Report"""
# Run the chunker on the sample. prob_threshold is the probability a token
# must exceed to start a new chunk, so lower values produce more chunks:
# close to 0 (e.g. 0.000001) nearly every token becomes its own chunk, while
# values approaching 1 leave the whole text as a single chunk. Tune as needed.
chunks, token_pos = chunk_text(model, doc, tokenizer, prob_threshold=0.1)

# Show each chunk under a header carrying its ordinal and the token index
# returned for it.
for i, (c, t) in enumerate(zip(chunks, token_pos)):
    print(f"-----chunk: {i}----token_idx: {t}--------", c, sep="\n")

Token indices sequence length is longer than the specified maximum sequence length for this model (934 > 255). Running this sequence through the model will result in indexing errors



>>>>>>>>> Chunking code docs...
Processing 932 tokens...
-----chunk: 0----token_idx: 0--------

APPENDIX GOVERNANCE SOCIAL RESPONSIBILITY ENVIRONMENTAL RESPONSIBILITY Al A.2 Key Performance Indicators KPIs A.2.1 OPERATIONS DATA an ee 1,598,606 444,057 1,503,763 417,714 1,557,968 432,769 116,004 68,551 773 EE Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Energy Intensity Total energy consumption, gigajoule GJ megawatt hour MWh Total fuel consumption from nonrenewable sources, gigajoule GJ megawatt hour MWh 404,021 112,228 360,169 100,048 417,613 Natural gas, gigajoule GJ megawatt hour MWh 235,616 65,449 203,999 56,667 246,783 Gasoline, gigajoule GJ megawatt hour MWh 2,376 660 1,860 517 2,784 Diesel, gigajoule GJ megawatt hour MWh 110,585 30,718 113,134 31,426 118,590 32,942 e Propane, gigajoule GJ megawatt hour MWh 7,499 2,069 6,520 1,811 7,532 2,092 e LPG, gigajoule GJ megawatt hour MWh 27,463 7,628 25,309