In [1]:
import os
import time

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face Hub checkpoint id; used for both the tokenizer and the model below.
model_name = "jordiclive/flan-t5-11b-summarizer-filtered-1.5-epoch"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Alternative loading kwargs kept for reference; they are NOT passed to
# from_pretrained() below.
#kwargs = dict(device_map="balanced_low_0", torch_dtype=torch.bfloat16)

# Wall-clock timestamp taken after the tokenizer load and before the model load.
t_start = time.time()
# NOTE(review): local_rank and world_size are not read anywhere in the visible
# code — confirm whether later cells use them before removing.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = torch.cuda.device_count()
# Passed to model.generate() as max_new_tokens inside generate().
target_length = 150
# NOTE(review): generate() declares its own max_source_length parameter
# (default 512) and does not read this module-level value.
max_source_length = 512

# Load the seq2seq weights in bfloat16. No device or device_map is given here.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
# Instruction prefixes selectable through generate()'s `summarization_type`
# argument (the dict key is the value to pass).
# Each prompt should end with a colon, because generate() joins it to the text
# as "<prompt> \n\n <input text>", e.g. "Summarize the following: \n\n <input text>".
example_prompts = {
    "social": "Produce a short summary of the following social media post:",
    "ten": "Summarize the following article in 10-20 words:",
    "5": "Summarize the following article in 0-5 words:",
    "100": "Summarize the following article in about 100 words:",
    "summary": "Write a ~ 100 word summary of the following text:",
    "short": "Provide a short summary of the following article:",
}


def generate(inputs, max_source_length=512, summarization_type=None, prompt=None):
    """Summarize a batch of texts with the module-level ``model`` and ``tokenizer``.

    Every text is stripped, prefixed with exactly one instruction, tokenized
    (padded/truncated to ``max_source_length``) and decoded with beam search.

    Args:
        inputs: Iterable of source texts. A single ``str`` is treated as a
            batch of one instead of being iterated character by character.
        max_source_length: Token length each prompted input is padded or
            truncated to.
        summarization_type: Key into ``example_prompts`` selecting a canned
            instruction. Ignored when ``prompt`` is given.
        prompt: Custom instruction text. Takes precedence over
            ``summarization_type`` so an input is never prefixed twice.

    Returns:
        A 3-tuple ``(inputs, outputs, total_new_tokens)`` of equally long
        lists: the prompted input strings, the decoded summaries, and the
        number of tokens generated for each summary.

    Raises:
        KeyError: If ``summarization_type`` is not a key of ``example_prompts``.
    """
    # Pick exactly one instruction; the default matches the original fallback.
    if prompt is not None:
        instruction = prompt.strip()
    elif summarization_type is not None:
        instruction = example_prompts[summarization_type].strip()
    else:
        instruction = "Summarize the following:"

    texts = [inputs] if isinstance(inputs, str) else list(inputs)
    if not texts:
        # Nothing to summarize; avoid handing an empty batch to the tokenizer.
        return [], [], []

    inputs = [f"{instruction} \n\n {text.strip()}" for text in texts]

    input_tokens = tokenizer.batch_encode_plus(
        inputs,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Follow the model's device instead of assuming CPU.
    input_tokens = input_tokens.to(model.device)

    outputs = model.generate(
        **input_tokens,
        use_cache=True,
        num_beams=5,
        min_length=5,
        max_new_tokens=target_length,
        no_repeat_ngram_size=3,
    )

    # For an encoder-decoder model the returned sequences hold decoder tokens
    # only (they do not echo the source), so "output length - input length" is
    # meaningless. Count the real generated tokens per row instead: skip the
    # decoder start token at position 0 and ignore batch padding.
    total_new_tokens = (
        outputs[:, 1:].ne(tokenizer.pad_token_id).sum(dim=-1).tolist()
    )
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return inputs, outputs, total_new_tokens





  from .autonotebook import tqdm as notebook_tqdm


: 

: 