In [1]:
from pathlib import Path
import rich
import torch
import transformers

DIR_PATH = Path("/home/mila/g/gagnonju/IteratedDecoding/jobs/tmp/")

pegasus_names = [
    "google/pegasus-pubmed", 
    "google/pegasus-arxiv",
]

bigbird_names = [
    "google/bigbird-pegasus-large-pubmed",
    "google/bigbird-pegasus-large-arxiv",
]

active_model_name = "google/bigbird-pegasus-large-pubmed"

# By default encoder-attention is `block_sparse` with num_random_blocks=3, block_size=64
tokenizer = transformers.AutoTokenizer.from_pretrained(active_model_name)

model = transformers.BigBirdPegasusForConditionalGeneration.from_pretrained(active_model_name)
if torch.cuda.is_available():
    model.to("cuda")

In [2]:
# https://pubmed.ncbi.nlm.nih.gov/30426489/
text = (DIR_PATH / "advances_breast_cancer.txt").read_text().replace("## ", "")
tmp = tokenizer(text, return_tensors='pt')
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
if torch.cuda.is_available():
    inputs.to("cuda")

Token indices sequence length is longer than the specified maximum sequence length for this model (6311 > 4096). Running this sequence through the model will result in indexing errors


1


In [6]:
min_length = 200
max_length = 256
num_beams = 12
num_beam_groups = 12
num_return_sequences = 4

method = "gbs_generate"

if method == "beam_search":
    predictions_tok = model.generate(
        **inputs, 
        repetition_penalty=3., 
        min_length=min_length, 
        num_beams=num_beams,
        max_length=max_length,
        temperature=2.0,
        do_sample=True,
        top_k=num_beams,
        num_return_sequences=num_return_sequences,
        early_stopping=True,
    )

elif method == "gbs_generate":
    predictions_tok = model.generate(
        **inputs, 
        repetition_penalty=3., 
        diversity_penalty=1000.,
        min_length=min_length, 
        num_beams=num_beams,
        max_length=max_length,
        num_beam_groups=num_beam_groups,
        num_return_sequences=num_return_sequences,
        early_stopping=True,
    )
    
elif method == "group_beam_search":
    scorer = transformers.BeamSearchScorer(
        batch_size=1,
        max_length=max_length,
        num_beams=num_beams,
        device="cuda" if torch.cuda.is_available() else "cpu",
        length_penalty=1.0,
        num_beam_hyps_to_keep=num_return_sequences,
        num_beam_groups=num_beam_groups,
    )

    prediction_tok = model.group_beam_search(
        **inputs, 
        repetition_penalty=2., 
        min_length=max_length, 
        diversity_penalty=0.75, 
        beam_scorer=scorer,
    )

else:
    raise ValueError(f"Unknown method: {method}")

for prediction_tok in predictions_tok:
    print(len(prediction_tok))

predictions = tokenizer.batch_decode(
    predictions_tok, truncate=False,
)




232
232
232
232


In [8]:
print(f"{len(predictions) = }")
for i, prediction in enumerate(predictions):
    one = prediction.replace("abstract", "[bold]Abstract:[/]\n"
                    ).replace("<n>", ""
                    ).replace("<s>", ""
                    ).replace("<pad>", ""
                    ).replace("</s>", ""
                    ).replace(" - ", "-")

    two = [x.strip().capitalize() for x in one.split(".") if x]
    three = ". ".join(two) + "."

    rich.print(f"[bold]Attempt #{i}[/]: \n{three}")

len(predictions) = 4


In [21]:
help(model.generate)

Help on method generate in module transformers.generation_utils:

generate(input_ids: Union[torch.LongTensor, NoneType] = None, max_length: Union[int, NoneType] = None, min_length: Union[int, NoneType] = None, do_sample: Union[bool, NoneType] = None, early_stopping: Union[bool, NoneType] = None, num_beams: Union[int, NoneType] = None, temperature: Union[float, NoneType] = None, top_k: Union[int, NoneType] = None, top_p: Union[float, NoneType] = None, repetition_penalty: Union[float, NoneType] = None, bad_words_ids: Union[Iterable[int], NoneType] = None, bos_token_id: Union[int, NoneType] = None, pad_token_id: Union[int, NoneType] = None, eos_token_id: Union[int, NoneType] = None, length_penalty: Union[float, NoneType] = None, no_repeat_ngram_size: Union[int, NoneType] = None, encoder_no_repeat_ngram_size: Union[int, NoneType] = None, num_return_sequences: Union[int, NoneType] = None, max_time: Union[float, NoneType] = None, max_new_tokens: Union[int, NoneType] = None, decoder_start_tok