In [13]:
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

class SimpleBART(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper that bundles a BART seq2seq model with its tokenizer.

    Args:
        model_name: Checkpoint name or path handed to ``from_pretrained`` for both the
            model and the tokenizer. Defaults to ``'facebook/bart-base'``.
    """

    def __init__(self, model_name='facebook/bart-base'):
        super().__init__()

        # Assigned as an attribute of an nn.Module, so its weights are registered as a
        # submodule and show up in model.parameters() (used by the optimizer).
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        # Loaded from the same checkpoint so the vocabulary matches the model.
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped BART model and return its output object unchanged.

        Args:
            input_ids: Encoder token ids.
            attention_mask: Mask aligned with ``input_ids``.
            labels: Optional target token ids, forwarded as-is to the wrapped model.
                NOTE(review): per the Hugging Face docs, BART then also builds decoder
                inputs from ``labels`` and returns a loss; with ``None`` the call is the
                same as before this parameter existed.

        Returns:
            Whatever ``self.bart(...)`` returns.
        """
        return self.bart(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def get_sentences_from_beam(self, beam_output):
        """Decode each generated sequence in ``beam_output`` into a string.

        Args:
            beam_output: Iterable of token-id sequences, e.g. the tensor returned by
                ``self.bart.generate(...)`` (one row per returned sequence).

        Returns:
            A list of decoded strings, one per sequence and in the same order, with
            special tokens removed. An empty input yields an empty list.
        """
        # batch_decode is the tokenizer's built-in per-sequence decode loop.
        return list(self.tokenizer.batch_decode(beam_output, skip_special_tokens=True))

# Build the wrapper, then the training objects on top of it.
# The optimizer receives every parameter registered under the wrapper (the BART weights).
model = SimpleBART()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)
# NOTE(review): optimizer/criterion are not used in the cells shown here —
# presumably consumed by a later training cell.
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

In [14]:
# Read one source text, beam-search several candidate outputs, and print each one.
NUM_BEAMS = 4     # beam width; also the number of candidates returned and printed
MAX_LENGTH = 200  # upper bound passed to generate() as max_length

english_text = input("Enter the source text: ")

print(f"> {english_text}")

# Let the tokenizer produce both tensors so the mask has the right dtype and shape,
# and actually pass it to generate() (the old hand-built float mask was never used).
encoded = model.tokenizer(english_text, return_tensors="pt")
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]

# Ensure dropout is off even if a training cell ran earlier in this kernel.
model.eval()

with torch.no_grad():
    beam_output = model.bart.generate(
        input_ids,
        attention_mask=attention_mask,
        num_beams=NUM_BEAMS,
        max_length=MAX_LENGTH,
        num_return_sequences=NUM_BEAMS,  # must not exceed num_beams
        eos_token_id=model.tokenizer.eos_token_id,
    )
    sentences = model.get_sentences_from_beam(beam_output)

for sentence in sentences:
    print(sentence)

> skirt hundred kindergarten aeroplane arrival hot read
skirt hundred kindergarten aeroplane arrival hot read
skirt thousand kindergarten aeroplane arrival hot read
skirt hundred kindergarten aeroplane arrivals hot read
skirt hundred kindergarten aeroplane delivery hot read
