In [1]:
import json
from transformers import AutoTokenizer, AutoConfig
from tqdm import tqdm
import matplotlib.pyplot as plt
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instruction text substituted into the {instruction} slot of `prompt`.
# NOTE(review): "2 sentence" reads like a typo, but the string is tokenized as
# part of every source length, so it is left untouched here — confirm it matches
# the prompt used wherever these counts are consumed.
question = "Summarize the above article in 2 sentence."
# Template for the full model input: article, newline, instruction, wrapped in
# [INST] ... [/INST].
prompt = "[INST]{article}\n{instruction}[/INST]"

In [3]:
# Identifier passed to AutoConfig / AutoTokenizer.from_pretrained.
model_name = "Yukang/LongAlpaca-7B"
# Directory handed to from_pretrained as cache_dir.
cache_dir = "../cache"
# Target context length; compared against the config's max_position_embeddings
# to decide on RoPE scaling and the tokenizer's model_max_length.
context_size = 32768

In [4]:
# Load the model config so its native context length can be compared with
# the requested context_size.
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)

# None when the config does not define max_position_embeddings.
orig_ctx_len = getattr(config, "max_position_embeddings", None)
if orig_ctx_len and context_size > orig_ctx_len:
    # Linear RoPE scaling by the smallest whole factor that covers context_size.
    scaling_factor = float(math.ceil(context_size / orig_ctx_len))
    config.rope_scaling = {"type": "linear", "factor": scaling_factor}

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    # Larger of the requested and native context lengths. `or 0` keeps this
    # working when orig_ctx_len is None, where a bare `>` comparison against
    # None would raise TypeError.
    model_max_length=max(context_size, orig_ctx_len or 0),
    padding_side="right",
    use_fast=False,
)

tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 15.6MB/s]
added_tokens.json: 100%|██████████| 21.0/21.0 [00:00<00:00, 94.1kB/s]
special_tokens_map.json: 100%|██████████| 438/438 [00:00<00:00, 1.10MB/s]
tokenizer.json: 100%|██████████| 1.84M/1.84M [00:00<00:00, 25.7MB/s]


In [10]:
def num_tokens_from_string(string):
    """Return how many token ids the notebook's tokenizer produces for `string`.

    The count is the length of ``tokenizer.encode(string)`` called with its
    default arguments, so whatever encode emits by default is included.
    """
    token_ids = tokenizer.encode(string)
    return len(token_ids)

def calculate_length(path):
    """Tokenize every record of a JSON file and return per-record token counts.

    Args:
        path: Path to a JSON file whose top-level value is iterable and whose
            items can be indexed with the keys ``'source'`` and ``'summary'``.

    Returns:
        A ``(source_lengths, target_lengths)`` pair of lists, in file order.
        ``source_lengths`` holds the token count of the full prompt (the
        record's source formatted into ``prompt`` together with ``question``);
        ``target_lengths`` holds the token count of the record's summary.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's locale
    # encoding, which can mis-decode or reject non-ASCII text on some systems.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    source_lengths = []
    target_lengths = []
    for meeting in tqdm(data):
        model_input = prompt.format_map({"article": meeting['source'], "instruction": question})
        source_lengths.append(num_tokens_from_string(model_input))
        target_lengths.append(num_tokens_from_string(meeting['summary']))
    return source_lengths, target_lengths

In [6]:
def min_max_avg(numbers):
    """Return ``(minimum, maximum, arithmetic mean)`` of `numbers`.

    `numbers` must be non-empty; an empty input raises from ``min``.
    """
    smallest, largest = min(numbers), max(numbers)
    mean = sum(numbers) / len(numbers)
    return smallest, largest, mean

In [7]:
def plot_lengths(data0, data1=None):
    """Draw a 100-bin histogram of `data0`, optionally overlaid with `data1`.

    Args:
        data0: Values plotted in blue and labelled 'Source'.
        data1: Optional second set of values plotted in red and labelled
            'Target' on the same axes.
    """
    plt.hist(data0, bins=100, alpha=0.5, color='blue', label='Source')
    if data1 is not None:
        plt.hist(data1, bins=100, alpha=0.5, color='red', label='Target')
    # The label= arguments above are only rendered once a legend is requested.
    plt.legend()
    plt.show()

In [23]:
def show_stat(stat):
    """Print combined source+target length statistics for every entry of `stat`.

    `stat` maps a task name to a ``(source_lengths, target_lengths)`` pair. For
    each task this prints a banner, the min / max / average of the per-item
    totals, and how many totals fall into each length bucket.
    """
    # Inclusive upper bounds of the buckets reported as 4k, 8k, 16k and 32k;
    # totals beyond the last bound are reported as "above".
    bucket_bounds = (4000, 8000, 16000, 32000)

    for task, lengths in stat.items():
        print(f'==================== {task} ====================')
        source_lengths, target_lengths = lengths
        # Pair items by position; indexing the target list keeps the pairing
        # driven by the number of source entries.
        total_lengths = [src + target_lengths[idx] for idx, src in enumerate(source_lengths)]

        # Optional per-side summaries (disabled):
        # source_min, source_max, source_avg = min_max_avg(source_lengths)
        # print(f'{task}-source: min={source_min}, max={source_max}, avg={source_avg}')
        # target_min, target_max, target_avg = min_max_avg(target_lengths)
        # print(f'{task}-target: min={target_min}, max={target_max}, avg={target_avg}')

        total_min, total_max, total_avg = min_max_avg(total_lengths)
        print(f'{task}-total: min={total_min}, max={total_max}, avg={total_avg}')

        # One counter per bucket plus a final slot for "above". Because the
        # bounds ascend, the number of bounds a total exceeds is its bucket index.
        counts = [0] * (len(bucket_bounds) + 1)
        for total in total_lengths:
            counts[sum(total > bound for bound in bucket_bounds)] += 1
        in_4k, in_8k, in_16k, in_32k, above = counts
        print(f'{task}-total: 4k={in_4k}, 8k={in_8k}, 16k={in_16k}, 32k={in_32k}, above={above}')

        # Optional histograms (disabled):
        # plot_lengths(source_lengths, target_lengths)
        # plot_lengths(total_lengths)

In [18]:
# Token lengths per split, keyed by split name; each value is the
# (source_lengths, target_lengths) pair returned by calculate_length.
stats = {}
for task in ('test', 'validation', 'train'):
    # Each split is read from "<split>_segment.json" in the working directory.
    path = f'{task}_segment.json'
    stats[task] = calculate_length(path)

100%|██████████| 862/862 [00:53<00:00, 16.17it/s]
100%|██████████| 861/861 [00:47<00:00, 18.06it/s]
100%|██████████| 5169/5169 [05:49<00:00, 14.78it/s]


In [24]:
# Print the min/max/avg and length-bucket counts for every split in `stats`.
show_stat(stats)

test-total: min=236, max=96605, avg=4205.548723897912
test-total: 4k=631, 8k=123, 16k=69, 32k=27, above=12
validation-total: min=215, max=87276, avg=4162.984901277584
validation-total: 4k=648, 8k=95, 16k=72, 32k=33, above=13
train-total: min=249, max=100985, avg=4638.439736893016
train-total: 4k=3666, 8k=730, 16k=436, 32k=240, above=97
