In [1]:
# Restrict this kernel to physical GPU 2. This cell runs before `import torch`
# (next cell), so CUDA only ever sees that one device.
%env CUDA_VISIBLE_DEVICES=2

env: CUDA_VISIBLE_DEVICES=2


In [8]:
import torch
from transformers import BartTokenizer, BartTokenizerFast, BartForConditionalGeneration, BartConfig
import textwrap

In [3]:
# Target device for the model and its inputs. Fall back to CPU when CUDA is not
# available so the notebook still survives Restart & Run All on a GPU-less machine.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Library-default BartConfig, built with no arguments.
# NOTE(review): this object is not used later in the visible notebook, and its
# defaults are not necessarily the facebook/bart-base values — those come from
# BartConfig.from_pretrained("facebook/bart-base").
configuration = BartConfig()

In [5]:
# from_pretrained is a classmethod that returns a NEW config object. Assign it;
# otherwise `configuration` keeps whatever value it had before and the loaded
# bart-base config is only displayed, then discarded.
configuration = BartConfig.from_pretrained("facebook/bart-base")
configuration

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

BartConfig {
  "_name_or_path": "bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bar

In [6]:
# Pretrained BART-base seq2seq model, moved onto DEVICE.
model = BartForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="facebook/bart-base",
).to(DEVICE)

Downloading pytorch_model.bin:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [9]:
# Fast (Rust-backed) tokenizer matching the facebook/bart-base checkpoint.
tokenizer = BartTokenizerFast.from_pretrained(
    pretrained_model_name_or_path="facebook/bart-base",
)

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [None]:
# Sanity-check a single forward pass. Calling model.forward() with no inputs
# raises, so feed a tokenized sample instead. When decoder_input_ids are omitted,
# BART derives them from input_ids.
with torch.no_grad():
    sample_inputs = tokenizer("That is good.", return_tensors="pt").to(DEVICE)
    sample_outputs = model(**sample_inputs)
sample_outputs.logits.shape

In [10]:
def bold(text):
    """Wrap ``text`` in ANSI bold / reset escape codes for terminal-style output."""
    return '\033[1m' + text + '\033[0m'


def generate_text(text, tokenizer, model, num_beams=5, num_sentences=5,
                  min_length=30, max_length=100, early_stopping=True,
                  device=None):
    """Beam-search ``num_sentences`` outputs for ``text`` and print them.

    Args:
        text: Prompt string fed to the model.
        tokenizer: Tokenizer providing ``encode(text, return_tensors="pt")``
            and ``decode(ids, skip_special_tokens=True)``.
        model: Model providing ``generate`` and a ``device`` attribute.
        num_beams: Beam width passed to ``model.generate``.
        num_sentences: Number of sequences to return (passed as
            ``num_return_sequences``).
        min_length: Minimum generation length passed to ``model.generate``.
        max_length: Maximum generation length passed to ``model.generate``.
        early_stopping: Passed through to ``model.generate``.
        device: Device to place the input ids on. Defaults to ``model.device``,
            so the inputs live where the weights are, without relying on a
            notebook-global ``DEVICE``.

    Returns:
        None. The prompt and every generated sequence are printed.
    """
    if device is None:
        device = model.device

    print(bold("Input Text:"))
    print(textwrap.fill(text, width=70), end='\n\n')

    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

    # One row per returned sequence.
    gen_text_ids = model.generate(input_ids,
                                  num_beams=num_beams,
                                  num_return_sequences=num_sentences,
                                  min_length=min_length,
                                  max_length=max_length,
                                  early_stopping=early_stopping)

    # Keep the blank line outside the escape codes so the reset code does not
    # spill onto the following line.
    print(bold("Generated texts:"), end='\n\n')
    for i, beam_output in enumerate(gen_text_ids):
        output = tokenizer.decode(beam_output, skip_special_tokens=True)
        print(bold(f"Generated text {i}:"))
        print(textwrap.fill(output, width=70), end='\n\n')

In [34]:
# Token-tensor shape for a sample prompt. The prompt is defined inline so this
# cell does not depend on `text`, which is only assigned in a later cell.
sample_text = "translate English to German: That is good."
tokenizer.encode(sample_text, return_tensors="pt").shape

torch.Size([1, 10])

In [11]:
# NOTE(review): facebook/bart-base does not appear to follow the T5-style
# "translate English to German:" prefix. The sample output below echoes the
# prompt and rambles instead of translating it.
text = "\ntranslate English to German: That is good.\n"

generate_text(
    text,
    tokenizer,
    model,
    min_length=100,  # forces long generations for a short prompt; see run-on output
    max_length=150,
)

[1mInput Text:[0m
 translate English to German: That is good.

[1mGenerated texts:
[0m
[1mGenerated text 0:[0m
Transtranslate English to German: That is good. Thanks for the help.Ad
vertisementsadvertisementadvertisementAbstractBackground:Introduction:
Introduction: Introduction.Introduction:Background:Background.Descript
ion:Description:Introduction.Introduction.Description.Translate
English and German.Background.Transparent English to
English.Background:Description.English:English: That's English.English
to English: That wouldtranslateEnglish to German-English to Spanish:
That.translate to English to Japanese:That is good, right?

[1mGenerated text 1:[0m
Transtranslate English to German: That is good. Thanks for the help.Ad
vertisementsadvertisementadvertisementAbstractBackground:Introduction:
Introduction: Introduction.Introduction:Background:Background.Descript
ion:Description:Introduction.Introduction.Description.Translate
English and German.Background.Transparent English 