# Notebook for text generation

In [1]:
from transformers import BertTokenizer, GPT2Tokenizer,  GPT2LMHeadModel, GPT2Tokenizer, BertForMaskedLM
import scipy
import pandas as pd
import numpy as np
import os
import torch
import glob
from BERT.tokenizer import tokenize

### Functions

In [2]:
def encode(line, t):
    encoded_dict = t.encode_plus(line, 
                   add_special_tokens=True,
                   max_length = 128,
                   return_attention_mask = True,
                   return_tensors = 'pt')
    return encoded_dict

def decode(distribution, t):
    s = torch.nn.Softmax(dim=0)
    proba = s(distribution.detach()).numpy()
    token = np.random.choice(np.arange(len(proba)), 1, p=proba)
    N_max = [t.decode(item) for item in [np.random.choice(np.arange(len(proba)), 10, p=proba)]]
    word = t.decode(token)
    return token, word, N_max

In [3]:
def next_input(last_input, new_token, bos_token=False, eos_token=False):
    if bos_token:
        if eos_token:
            result = torch.cat([last_input[:, 0].unsqueeze(0), last_input[:,2:-1], torch.ones((1,1), dtype=torch.long)* (int(new_token[0])),last_input[:, -1].unsqueeze(0)], dim=1)
        else:
            result = torch.cat([last_input[:, 0].unsqueeze(0), last_input[:,2:], torch.ones((1,1), dtype=torch.long)* (int(new_token[0]))], dim=1)
    elif eos_token:
        result = torch.cat([last_input[:,1:-1], torch.ones((1,1), dtype=torch.long)* (int(new_token[0])),last_input[:, -1].unsqueeze(0)], dim=1)
    else:
        result = torch.cat([last_input[:,1:], torch.ones((1,1), dtype=torch.long)* (int(new_token[0]))], dim=1)
    return result

### Model instanciation

In [4]:
#t_bert = BertTokenizer.from_pretrained('bert-base-cased')

In [5]:
t_gpt2_base = GPT2Tokenizer.from_pretrained('gpt2')

In [6]:
#t_gpt2_medium = GPT2Tokenizer.from_pretrained('gpt2-medium')

In [7]:
#model_bert = BertForMaskedLM.from_pretrained('bert-base-cased')

In [8]:
model_gpt2_base = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=t_gpt2_base.eos_token_id)

In [9]:
#model_gpt2_medium = GPT2LMHeadModel.from_pretrained('gpt2-medium', pad_token_id=t_gpt2_medium.eos_token_id)

### First token

In [128]:
prompt = ' Once , when I was six years old, I saw a magnificent picture in a book about the primeval forest called ‘Real-life Stories.’ It showed a boa constrictor swallowing a wild animal. Here is a copy of the drawing.'
language = 'english'


In [129]:
prompt_processed = ' '.join(tokenize(prompt, language))

100%|██████████| 3/3 [00:00<00:00, 9293.14it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





In [130]:
prompt_processed

' Once , when I was six years old , I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories . ’  It showed a boa constrictor swallowing a wild animal .  Here is a copy of the drawing .'

In [13]:
nb_token_to_generate = 150

### BERT text  generation 

In [16]:
#input_ids

In [17]:
#encoded_dict = encode(prompt_processed, t_bert)
#input_ids = encoded_dict['input_ids']

In [18]:
#greedy_output = model_bert.generate(input_ids, max_length=150, num_beams=5, early_stopping=True, no_repeat_ngram_size=2)
#print(t_bert.decode(greedy_output[0], skip_special_tokens=True))

In [19]:
#result = ''

In [20]:
#encoded_dict = encode(prompt_processed, t_bert)
#input_ids = encoded_dict['input_ids']
#attention_mask = encoded_dict['attention_mask']
#token_type_ids = encoded_dict['token_type_ids']
#for _ in range(nb_token_to_generate):
#    out = model_bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
#    token, word, N_max = decode(out[0].squeeze()[-1], t_bert)
#    print(N_lmax)
#    result += ' ' + word
#    
#    input_ids = next_input(input_ids, token, bos_token=True, eos_token=True)

In [21]:
#result

### GPT2-base text generation

In [22]:
result = ''

In [153]:
encoded_dict = encode(prompt_processed, t_gpt2_base)
input_ids = encoded_dict['input_ids']

In [154]:
t_gpt2_base.decode(torch.Tensor([373])) # 198,   198,    40,   373, 257,  1310,  1643]]

' was'

In [156]:
greedy_output.shape

torch.Size([1, 150])

In [155]:
greedy_output = model_gpt2_base.generate(input_ids, max_length=150, ) #num_beams=5, early_stopping=True, no_repeat_ngram_size=2
print(t_gpt2_base.decode(greedy_output[0], skip_special_tokens=True))

input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764]])
input_ids after preparation:  <class 'NoneType'> tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1]]) True
tokens_to_add:  tensor([220])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  12

tokens_to_add:  tensor([262])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262]])
input_ids after preparation:  <class 'tuple'> tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) True
tokens_to_add:  tensor([4286])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  146

tokens_to_add:  tensor([314])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314]])
input_ids after preparation:  <class 'tuple'> tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1,

tokens_to_add:  tensor([1310])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310]])
input_ids after preparation:  <class 'tuple'> tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1,

tokens_to_add:  tensor([198])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310,
          2933,   618,   314,  2497,   262,  4286,    13,   220,   198,   198]])
input_ids after preparation:  <class 'tuple'> tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1,

tokens_to_add:  tensor([4286])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310,
          2933,   618,   314,  2497,   262,  4286,    13,   220,   198,   198,
            40,   373,   257,  1310,  2933,   618,   314,  2497,   262,  4286]])
input_ids after preparation:  <class 'tuple'> tensor([[1, 1, 1, 1, 

tokens_to_add:  tensor([618])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310,
          2933,   618,   314,  2497,   262,  4286,    13,   220,   198,   198,
            40,   373,   257,  1310,  2933,   618,   314,  2497,   262,  4286,
            13,   220,   198,   198,    40,   373,   257,  1310,  2933

tokens_to_add:  tensor([40])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310,
          2933,   618,   314,  2497,   262,  4286,    13,   220,   198,   198,
            40,   373,   257,  1310,  2933,   618,   314,  2497,   262,  4286,
            13,   220,   198,   198,    40,   373,   257,  1310,  2933,

tokens_to_add:  tensor([262])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310,
          2933,   618,   314,  2497,   262,  4286,    13,   220,   198,   198,
            40,   373,   257,  1310,  2933,   618,   314,  2497,   262,  4286,
            13,   220,   198,   198,    40,   373,   257,  1310,  2933

tokens_to_add:  tensor([1310])
input_ids before preparation:  tensor([[ 4874,   837,   618,   314,   373,  2237,   812,  1468,   837,   314,
          2497,   257, 25023,  4286,   287,   257,  1492,   546,   262,  6994,
          2100,  8222,  1444,   564,   246,  6416,   532,  1204, 18152,   764,
           564,   247,   220,   632,  3751,   257,  1489,    64,  1500,  2012,
           273, 45590,   257,  4295,  5044,   764,   220,  3423,   318,   257,
          4866,   286,   262,  8263,   764,   220,   198,   198,    40,   373,
           257,  1310,  2933,   618,   314,  2497,   262,  4286,    13,   220,
           198,   198,    40,   373,   257,  1310,  2933,   618,   314,  2497,
           262,  4286,    13,   220,   198,   198,    40,   373,   257,  1310,
          2933,   618,   314,  2497,   262,  4286,    13,   220,   198,   198,
            40,   373,   257,  1310,  2933,   618,   314,  2497,   262,  4286,
            13,   220,   198,   198,    40,   373,   257,  1310,  293

In [136]:
def encode(line, t):
    t.pad_token = t.eos_token
    encoded_dict = t.encode_plus(line, 
                   add_special_tokens=True,
                   max_length = 128,
                   return_attention_mask = True,
                   pad_to_max_length = False,
                   return_tensors = 'pt',
                   add_prefix_space=False)
    return encoded_dict

In [135]:
t_gpt2_base.prepare_for_model(input_ids[mask], pad_to_max_length=True, max_length=120)

IndexError: The shape of the mask [1, 278] at index 1 does not match the shape of the indexed tensor [1, 116] at index 1

In [158]:
def next_input(input_ids, out, bos_token=False, eos_token=False):
    next_token_logits = out[0][:, -1, :]
    tokens_to_add = torch.argmax(next_token_logits, dim=-1)
    result = torch.cat([input_ids, tokens_to_add.unsqueeze(0)], dim=-1)
    return result

In [159]:
encoded_dict = encode(prompt_processed, t_gpt2_base)
input_ids = encoded_dict['input_ids']
attention_mask = encoded_dict['attention_mask']
token_type_ids = encoded_dict['token_type_ids']
for _ in range(nb_token_to_generate):
    out = model_gpt2_base(input_ids=input_ids, attention_mask=attention_mask, past=None) #, token_type_ids=token_type_ids
    #token, word, N_lmax = decode(out[0].squeeze()[-1], t_gpt2_base)
    #print(N_lmax)
    #result += ' ' + word
    input_ids = next_input(input_ids, out)
    attention_mask = torch.ones(input_ids.shape[-1], dtype=torch.long).unsqueeze(0)
    #token_type_ids = torch.zeros(input_ids.shape[-1], dtype=torch.long).unsqueeze(0)
    #attention_mask = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

In [161]:
print(t_gpt2_base.decode(input_ids[0]))

 Once, when I was six years old, I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories. ’  It showed a boa constrictor swallowing a wild animal.  Here is a copy of the drawing. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I saw the picture. 

I was a little boy when I


In [None]:
result

### GPT2-medium text generation

In [None]:
result = ''

In [None]:
encoded_dict = encode(prompt_processed, t_gpt2_medium)
input_ids = encoded_dict['input_ids']
attention_mask = encoded_dict['attention_mask']
token_type_ids = encoded_dict['token_type_ids']
for _ in range(nb_token_to_generate):
    out = model_gpt2_medium(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    token, word = decode(out[0].squeeze()[-1], t_gpt2_medium)
    result += ' ' + word
    input_ids = next_input(input_ids, token)

In [None]:
result