<a href="https://colab.research.google.com/github/GarettGazay/ai_projects/blob/master/Seq2Seq_Longformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import torch
from transformers import LEDTokenizer, LEDForConditionalGeneration

# define the model and tokenizer
model_name = "allenai/led-base-16384"
tokenizer = LEDTokenizer.from_pretrained(model_name)
model = LEDForConditionalGeneration.from_pretrained(model_name)



In [4]:
# define the input and target sequences
inputs = ['ride0', 'ride3', 'ride1', 'ride2']
targets = ['ride0', 'ride1', 'ride2', 'ride3']


In [5]:
# tokenize the input and target sequences
input_ids = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")
target_ids = tokenizer.batch_encode_plus(targets, padding=True, return_tensors="pt")


In [6]:
input_ids

{'input_ids': tensor([[    0, 23167,   288,     2],
        [    0, 23167,   246,     2],
        [    0, 23167,   134,     2],
        [    0, 23167,   176,     2]]), 'attention_mask': tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])}

In [7]:
target_ids

{'input_ids': tensor([[    0, 23167,   288,     2],
        [    0, 23167,   134,     2],
        [    0, 23167,   176,     2],
        [    0, 23167,   246,     2]]), 'attention_mask': tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])}

In [8]:
# retrieve the input and target token IDs and attention masks
encoder_input_ids = input_ids["input_ids"]
decoder_input_ids = target_ids["input_ids"]
decoder_attention_mask = target_ids["attention_mask"]

In [11]:
# generate output sequence
generated_ids = model.generate(
    input_ids=encoder_input_ids,
    attention_mask=input_ids["attention_mask"],
    decoder_start_token_id=tokenizer.pad_token_id,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    use_cache=True,
    max_length=128,
    num_beams=1,
)


In [12]:
# decode output sequence
generated_sequence = tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)

print(generated_sequence)

['ride0', 'ride1', 'ride2', 'ride3']
