# BART Learns To Rap
Notebook finetuning the BART model to create hip hop lyrics in a sequence to sequence fashion. Much of this notebook borrows code from Hugging Face's finetune.py, modeling_bart.py and the utils.py in the git repository for seq2seq models. The lyrics for fine tuning this model were taken from this github repo: http://www.github.com/fpaupier/RapLyrics-Scraper

The source sentences were noised (some words replaced by a mask token) before being loaded into this notebook. The noising function is still included here because text is also noised during generation. This adds additional stochasticity to the model, leading (I think) to greater variation in the kinds of lyrics that the model generates.

# Firing up Google Drive
Mount your Google Drive to load the lyrics for training and to save model weights.

In [None]:
from google.colab import drive

# Mount Google Drive (without forcing a remount if it is already mounted) so the
# training CSVs and model checkpoints stored under MyDrive are reachable.
drive.mount('/content/gdrive',
            force_remount=False)


Mounted at /content/gdrive


In [None]:
# This run uses PyTorch Lightning to finetune the model
!pip install -q pytorch-lightning
!pip install -q transformers
# tqdm (the package imported below), not the misspelled "tdqm"
!pip install -q tqdm

[K     |████████████████████████████████| 563kB 8.7MB/s 
[K     |████████████████████████████████| 276kB 54.1MB/s 
[K     |████████████████████████████████| 92kB 12.6MB/s 
[K     |████████████████████████████████| 829kB 57.4MB/s 
[?25h  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone
  Building wheel for future (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 1.3MB 8.3MB/s 
[K     |████████████████████████████████| 2.9MB 33.4MB/s 
[K     |████████████████████████████████| 890kB 57.2MB/s 
[K     |████████████████████████████████| 1.1MB 35.8MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Building wheel for tdqm (setup.py) ... [?25l[?25hdone


In [None]:
# imports
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np

import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

from tqdm import tqdm, trange
import math
import random
import re
import argparse

# ****REMEMBER CHANGE STUFF HERE set up variables for ppl and gen

In [None]:
root_dir = "/content/gdrive/MyDrive/"
base_dir = f"{root_dir}NLP CS395T/Artist-ic Endeavor"

# ****REMEMBER: Change these to the right files for each run
train_data_file_name = 'lyrics_tagged_ablation_7_noised.csv'
model_file_name = 'tacc_checkpoints/bart_a7/epoch=14.ckpt'
model_name = 'bart_a7'
persona_file = f"{base_dir}/Persona Ablations/persona_ablation_7.csv"

# Full paths inside the project directory on Drive
train_data_file = f"{base_dir}/training_csvs/{train_data_file_name}"
model_file = f"{base_dir}/models/{model_file_name}"

# Pytorch Lightning for running the training
The code below uses PyTorch Lightning to train the model; Lightning is explained very well (and simply) at https://pytorch-lightning.readthedocs.io/en/latest/. Very briefly, most of the usual methods one would set up for a PyTorch class are set up in a pl.LightningModule class. Lightning then automates much of the training loop, for example stepping the optimizer and clearing gradients.

In [None]:
class LitModel(pl.LightningModule):
  ''' PyTorch Lightning wrapper that fine-tunes a BART-style conditional-generation model.

      Args: learning_rate - Adam learning rate
            tokenizer - tokenizer whose pad_token_id is ignored by the loss and whose special
                        tokens are banned during generation
            model - the seq2seq model being trained (e.g. BartForConditionalGeneration)
            hparams - namespace with boolean ``freeze_encoder`` and ``freeze_embeds`` flags
  '''

  # Tokens generation must never produce: S (the BOS token, also the decoder start token),
  # <mask>, <pad>, <unk>, </s>, <s>
  _SKIP_GEN_TOKENS = ('S', '<mask>', '<pad>', '<unk>', '</s>', '<s>')

  # Instantiate the model
  def __init__(self, learning_rate, tokenizer, model, hparams):
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model
    self.learning_rate = learning_rate
    # NOTE(review): direct assignment works with the Lightning version this notebook targets;
    # newer releases may make ``hparams`` read-only (save_hyperparameters) — confirm on upgrade.
    self.hparams = hparams

    if self.hparams.freeze_encoder:
      freeze_params(self.model.get_encoder())

    if self.hparams.freeze_embeds:
      self.freeze_embeds()

  def freeze_embeds(self):
    ''' Freeze the shared embeddings and the encoder/decoder positional and token embeddings;
        adapted from finetune.py '''
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  # Do a forward pass through the model
  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def _seq2seq_loss(self, batch):
    ''' Teacher-forced cross-entropy of the target ids given the source ids (pad ids ignored).
        ``batch`` is (source ids, source attention mask, target ids). '''
    src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
    pad_id = self.tokenizer.pad_token_id
    # Shift the decoder inputs right (but NOT the tgt_ids the loss is computed on)
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_id)

    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_id)
    return ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

  def training_step(self, batch, batch_idx):
    return {'loss': self._seq2seq_loss(batch)}

  def validation_step(self, batch, batch_idx):
    return {'loss': self._seq2seq_loss(batch)}

  # Method that generates text using the BartForConditionalGeneration's generate() method
  def generate_text(self, text, eval_beams, early_stopping = True, max_len = 15):
    ''' Sample a continuation of ``text`` (a dict with an "input_ids" tensor) using nucleus
        sampling (top_p=0.9, top_k disabled), starting the decoder at the BOS token and
        stopping at EOS. Returns the id tensor produced by ``model.generate``. '''
    # bad_words_ids bans the LAST id of each inner list, and only right after the list's
    # preceding ids were generated. Give every token its own single-id list so it is banned
    # outright (tokenized strings wrapped in <s> ... </s> would ban almost nothing).
    skip_gen_ids = [[tok_id] for tok_id in self.tokenizer.convert_tokens_to_ids(list(self._SKIP_GEN_TOKENS))]
    generated_ids = self.model.generate(
        # Follow the model's device instead of assuming CUDA
        text["input_ids"].to(self.model.device),
        use_cache=True,
        do_sample=True,
        top_p=0.9,
        top_k=0,
        decoder_start_token_id = self.tokenizer.bos_token_id,
        eos_token_id = self.tokenizer.eos_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        min_length = 0,
        early_stopping = early_stopping,
        bad_words_ids = skip_gen_ids,
    )
    return generated_ids

def freeze_params(model):
  ''' Freeze every parameter of ``model`` (a whole model or one of its sub-modules) so that
      no gradients are computed for it and the optimizer leaves it unchanged.
      Adapted from finetune.py '''
  for param in model.parameters():
    # The attribute is ``requires_grad``; the former ``requires_grade`` only created an unused
    # attribute and left every parameter trainable.
    param.requires_grad = False


In [None]:
# Create a dataloading module as per the PyTorch Lightning Docs
class SummaryDataModule(pl.LightningDataModule):
  ''' LightningDataModule over a CSV with ``source`` and ``target`` columns of space-separated tokens.

      Everything is loaded and tokenized eagerly in __init__. The first ``num_examples`` rows are
      split positionally into ``train_size`` training rows, the next ``val_size`` validation rows,
      and the remaining rows for testing (with the defaults and a full-size file this is roughly
      an 80/10/10 split).
  '''
  def __init__(self, tokenizer, data_file, batch_size, num_examples = 775959,
               train_size = 621405, val_size = 75795):
    super().__init__()
    self.tokenizer = tokenizer
    self.data_file = data_file
    self.batch_size = batch_size
    self.num_examples = num_examples
    self.data = pd.read_csv(self.data_file)[:self.num_examples]
    # Positional split with iloc: same rows as np.split(self.data, [train_size, train_size + val_size])
    # without relying on NumPy splitting a DataFrame
    val_end = train_size + val_size
    train_df = self.data.iloc[:train_size]
    val_df = self.data.iloc[train_size:val_end]
    test_df = self.data.iloc[val_end:]
    self.train = encode_sentences(self.tokenizer, train_df['source'], train_df['target'])
    self.validate = encode_sentences(self.tokenizer, val_df['source'], val_df['target'])
    self.test = encode_sentences(self.tokenizer, test_df['source'], test_df['target'])

  # Nothing to do: the data is already read and split in __init__
  def prepare_data(self):
      pass

  # Nothing to do: every split is already encoded in __init__
  def setup(self, stage):
      pass

  # Training batches are drawn in random order
  def train_dataloader(self):
    dataset = TensorDataset(self.train['input_ids'], self.train['attention_mask'], self.train['labels'])
    train_data = DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = self.batch_size)
    return train_data

  # Validation batches keep the file order
  def val_dataloader(self):
    dataset = TensorDataset(self.validate['input_ids'], self.validate['attention_mask'], self.validate['labels'])
    val_data = DataLoader(dataset, batch_size = self.batch_size)
    return val_data

  # Test batches keep the file order
  def test_dataloader(self):
    dataset = TensorDataset(self.test['input_ids'], self.test['attention_mask'], self.test['labels'])
    test_data = DataLoader(dataset, batch_size = self.batch_size)
    return test_data



In [None]:

# Hyper-parameters handed to LitModel. An argparse.Namespace is used purely as a light
# attribute container so the flags are easy to find and edit in one place.
hparams = argparse.Namespace(
    freeze_encoder=False,  # True would freeze the encoder weights
    freeze_embeds=False,   # True would freeze the shared/positional/token embeddings
    eval_beams=4,
)

In [None]:
def shift_tokens_right(input_ids, pad_token_id):
  """ Return a copy of ``input_ids`` shifted one position to the right, with each row's last
      non-pad token (usually <eos>) wrapped around into column 0.
      The last real token is found by counting non-pad tokens, so rows are expected to be
      right-padded. Taken directly from modeling_bart.py
  """
  last_real_pos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  wrapped_tokens = input_ids.gather(1, last_real_pos).squeeze()
  shifted = input_ids.clone()
  shifted[:, 1:] = input_ids[:, :-1]
  shifted[:, 0] = wrapped_tokens
  return shifted

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=160, pad_to_max_length=True, return_tensors="pt"):
  ''' Encode pre-tokenized (space-separated) source and target sentences.

      Each sentence is split on single spaces and mapped token-by-token to ids with
      ``tokenizer.convert_tokens_to_ids`` (no special tokens are added). The ids are then
      truncated to ``max_length`` and, when ``pad_to_max_length`` is true, padded to it.

      Args: tokenizer - the BART tokenizer
            source_sentences / target_sentences - iterables of space-separated token strings
            max_length - truncation (and padding) length
            pad_to_max_length - pad every sequence to ``max_length``. When False the per-sentence
                                tensors can only be concatenated if they already share a length.
            return_tensors - tensor type requested from the tokenizer ("pt" for torch)
      Returns: Dictionary with keys input_ids and attention_mask (from the sources) and
               labels (the un-shifted target ids)
  '''
  # False (not None) is the tokenizer's "do not pad" value
  padding = "max_length" if pad_to_max_length else False

  def _encode(sentence):
    ''' Encode one space-separated sentence into a batch-of-one dict. '''
    tok_ids = tokenizer.convert_tokens_to_ids(sentence.split(' '))
    return tokenizer.prepare_for_model(
        tok_ids,
        max_length=max_length,
        add_special_tokens=False,
        padding=padding,
        truncation=True,
        return_tensors=return_tensors,
        return_attention_mask=True,
        add_prefix_space=True,
        prepend_batch_axis = True,
    )

  input_ids = []
  attention_masks = []
  for sentence in source_sentences:
    encoded = _encode(sentence)
    input_ids.append(encoded['input_ids'])
    attention_masks.append(encoded['attention_mask'])

  # Only the ids are kept for targets. They stay un-shifted; shifting happens at training time.
  target_ids = [_encode(sentence)['input_ids'] for sentence in target_sentences]

  return {
      "input_ids": torch.cat(input_ids, dim = 0),
      "attention_mask": torch.cat(attention_masks, dim = 0),
      "labels": torch.cat(target_ids, dim = 0),
  }


def noise_sentence(sentence_, percent_words, replacement_token = "<mask>"):
    '''
    Randomly mask words of a space-separated sentence.

    Each word is independently selected with probability ``percent_words``. A selected word is
    replaced by ``replacement_token``, with two exceptions:
      - consecutive selected words collapse into a single replacement token;
      - the first word is always kept.

    Args: sentence_ - the sentence to noise
          percent_words - probability in [0, 1] that each word is masked
          replacement_token - token inserted in place of masked words
    Returns: the noised sentence, words joined by single spaces
    '''
    newsent = []
    for word in sentence_.split(' '):
        masked = random.random() < percent_words
        if not masked or not newsent:
            # Unmasked words, and the very first word, are always kept
            newsent.append(word)
        elif newsent[-1] != replacement_token:
            newsent.append(replacement_token)
        # otherwise the previous mask already covers this word
    return " ".join(newsent)
  

# Load BART
Here we load the model. I used "bart-base" because I had memory issues using "bart-large". "bart-base" appears to load without the use_cache argument, which by necessity must be turned to "False" for "bart-large".

In [None]:
# Load the model
import os
import json

from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig

# BOS/EOS are remapped to the single-letter markers "S" and "L"
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base", add_prefix_space=True, bos_token = "S", eos_token = "L", mask_token = "<mask>")
print(len(tokenizer))

# Add the corpus' BPE vocabulary (the keys of the token->id JSON) as new tokens.
# The context manager closes the file once it has been read.
with open("{}/bpe_string_token_to_int.json".format(base_dir), "r") as myfile:
  toadd = list(json.load(myfile).keys())
print(len(toadd))
tokenizer.add_tokens(toadd)
print(len(tokenizer))

bart_model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base")
# Grow the embedding matrix so the newly added tokens have rows
bart_model.resize_token_embeddings(len(tokenizer))

50265
31181
76643


Embedding(76643, 768)

In [None]:
# Tokenize the training CSV into train/validation/test splits
summary_data = SummaryDataModule(tokenizer, train_data_file, batch_size = 16)

# Resume from a pre-saved checkpoint. To start training from scratch instead, use:
# model = LitModel(learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)
model = LitModel.load_from_checkpoint(
    model_file,
    learning_rate = 2e-5,
    tokenizer = tokenizer,
    model = bart_model,
    hparams = hparams,
)

# Training the model with Pytorch Lightning
The code below uses PyTorch Lightning's Trainer module, which controls the training process. After creating a ModelCheckpoint object, the other options are passed to the Trainer. I found that my Colab session crashed when I didn't explicitly set progress_bar_refresh_rate. Setting it to 500 seemed to work just fine (this run uses 25).

In [None]:

# base_dir has no trailing slash, so the separator is needed here. Without it the checkpoints
# land in a sibling folder named "...Artist-ic Endeavorcheckpoint_files_2".
checkpoint = ModelCheckpoint(filepath=base_dir + '/checkpoint_files_2/')
trainer = pl.Trainer(gpus = 1,
                     max_epochs = 1,
                     min_epochs = 1,
                     auto_lr_find = False,
                     checkpoint_callback = checkpoint,
                     progress_bar_refresh_rate = 25)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Train: the data module supplies the train/validation dataloaders
trainer.fit(model, datamodule=summary_data)

In [None]:
# If you want to manually save a checkpoint, this works, although the model should automatically save (progressively better)
# checkpoints as it moves through the epochs
# trainer.save_checkpoint(base_dir + "checkpoint_files_2/8_ep_140k_simple_0210.ckpt")

# Getting BART to Rap
Now that we've trained BART for a few epochs on 140,000 lines of hip hop lyrics, there's a function below to get BART to generate lines of lyrics auto-regressively. The generate_lyrics() function allows one to use a number of previous lines to generate the next line, but I generally found that this didn't improve the lyrics generated. This was initially set up for training BART on multiple lines of lyrics, but the results from that approach weren't promising: the generated lyrics were very repetitive and copied lyrics from the training set.

I found that adding some noise in the line(s) that the next line is conditioned on generally made the lyrics more interesting... it's the kind of thing that can keep one regenerating lyrics for ages based on a single seed line. 

# Perplexity

In [None]:

from pytorch_lightning.core.decorators import auto_move_data

def perplexity(model, test_set, device="cuda"):
  ''' Compute exp(mean per-batch cross-entropy) of ``model`` over ``test_set``; print and return it.

      Args: model - a LitModel-like module exposing ``tokenizer`` and a seq2seq forward pass
            test_set - iterable of (source ids, source attention mask, target ids) batches
            device - device the model and every batch are moved to
      Returns: the perplexity

      Note: per-batch mean losses are averaged. This equals token-level perplexity only when every
      batch contains the same number of non-pad target tokens.
  '''
  model.eval()
  model.to(device)
  pad_id = model.tokenizer.pad_token_id
  # Built once; pad positions do not contribute to the loss
  ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_id)

  # A Python list avoids the quadratic cost of np.append inside the loop
  val_losses = []
  with torch.no_grad():
    for batch in tqdm(test_set):
      src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
      decoder_input_ids = shift_tokens_right(tgt_ids, pad_id)

      # Run the model and get the logits
      outputs = model(src_ids.to(device), attention_mask=src_mask.to(device),
                      decoder_input_ids=decoder_input_ids.to(device), use_cache=False)
      lm_logits = outputs[0]
      val_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.to(device).view(-1))
      val_losses.append(val_loss.item())
  ppl = np.exp(np.mean(val_losses))
  print("Perplexity is:", ppl)
  return ppl

In [None]:
# perplexity() moves the model to the GPU itself. Calling model.cpu() first only forced a
# GPU -> CPU -> GPU round trip of all the weights.
with torch.no_grad():
  perplexity(model, summary_data.test_dataloader())

100%|██████████| 4923/4923 [24:39<00:00,  3.33it/s]

Perplexity is: 72.71177444593066





# Generate Verses




In [None]:
def noise_ctxt(ctxt, percent_words, replacement_id):
    '''
    Randomly mask ids of a token-id sequence (the id-level counterpart of noise_sentence).

    Each id is independently selected with probability ``percent_words``. A selected id is
    replaced by ``replacement_id``, with two exceptions:
      - consecutive selected ids collapse into a single ``replacement_id``;
      - the first id is always kept.

    Args: ctxt - the sentence to noise (in ids); it is not modified
          percent_words - probability in [0, 1] that each id is masked
          replacement_id - the id of the mask token
    Returns: a new list of ids
    '''
    newctxt = []
    for id_ in ctxt:
        masked = random.random() < percent_words
        if not masked or not newctxt:
            # Unmasked ids, and the very first id, are always kept
            newctxt.append(id_)
        elif newctxt[-1] != replacement_id:
            newctxt.append(replacement_id)
        # otherwise the previous mask already covers this id
    return newctxt

def load_personas(personas_path, remove_tokens=False):
  ''' Read one whitespace-tokenized persona per line of a text file.

      Args: personas_path - path of the text file
            remove_tokens - if True, drop the span that starts at the 'G' marker token and runs up
                            to (not including) the next 'A' marker after it, else the next 'Y'
                            marker, else the end of the line
      Returns: dict mapping 1-based line numbers (ints) to the line's tokens followed by 'W'
  '''
  res = {}
  with open(personas_path, 'r', encoding='utf-8') as personas:
    for idx, line in enumerate(personas, start=1):
      tokens = line.split()
      # Markers are matched as whole tokens (not characters of the raw line)
      if remove_tokens and 'G' in tokens:
        start = tokens.index('G')
        end = len(tokens)
        for marker in ('A', 'Y'):
          if marker in tokens[start:]:
            end = tokens.index(marker, start)
            break
        tokens = tokens[:start] + tokens[end:]
      res[idx] = tokens + ['W']
  return res

def load_personas_from_csv(filepath, num_personas=91):
    ''' Load persona token lists from a CSV with a ``persona`` column.

    Args: filepath - path of the CSV
          num_personas - how many rows (from the top) to load; None loads every row
    Returns: dict mapping "1".."num_personas" (strings) to the row's whitespace-split tokens
             followed by 'W'
    '''
    persona_column = pd.read_csv(filepath)['persona']
    if num_personas is None:
        num_personas = len(persona_column)
    personas_dict = {}
    for i in range(num_personas):
        persona_tokens = persona_column[i].split()
        persona_tokens.append('W')
        personas_dict[str(i + 1)] = persona_tokens
    return personas_dict
    

def encode_sentence(sentence_tokens):
    ''' Encode an already-tokenized sentence (a list of token strings) with the global
        ``tokenizer`` into a batch-of-one dict with "input_ids" and "attention_mask" tensors.
        No special tokens are added and no max length is set. '''
    token_ids = tokenizer.convert_tokens_to_ids(sentence_tokens)
    return tokenizer.prepare_for_model(
        token_ids,
        add_special_tokens=False,
        return_tensors='pt',
        return_attention_mask=True,
        add_prefix_space=True,
        prepend_batch_axis=True,
    )

def generate_lyrics(persona_ids, num_lines, model_, noise_percent = 0.25):
  ''' Generate ``num_lines`` consecutive lyric lines for one persona.

      The first line is conditioned on the persona alone. Every later line is conditioned on the
      persona followed by a noised copy of (at most) the last 100 previously generated token ids.

      Args: persona_ids - encoded persona, a dict whose "input_ids" tensor has a leading batch
                          axis of 1 (as produced by encode_sentence)
            num_lines - the number of lines to generate
            model_ - the LitModel used to generate the text
            noise_percent - probability of masking each context id (see noise_ctxt)
      Returns: a flat list of decoded strings, one per generated token id (lines are not joined)
  '''
  # Put the model on eval mode; use the model passed in (not a global) throughout
  model_.eval()
  model_.cuda()
  mask_id = model_.tokenizer.mask_token_id
  lyrics = []
  ctxt = []
  for _ in range(num_lines):
    if lyrics:
      noised_ctxt = noise_ctxt(ctxt[-100:], noise_percent, mask_id)
      model_input = {
          'input_ids': torch.cat([persona_ids['input_ids'], torch.tensor(noised_ctxt).unsqueeze(0).long()], dim=-1)
      }
    else:
      model_input = persona_ids
    line = model_.generate_text(model_input, eval_beams = 1)
    # A single input row yields a single generated row. Index it instead of squeeze(), so a
    # length-1 output still gives a list of ids rather than a bare int.
    line_ids = line[0].tolist()
    ctxt.extend(line_ids)
    lyrics.extend(model_.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in line_ids)
  return lyrics

import os
metric_dir = f'{base_dir}/metrics/{model_name}'

def generate_verses(model, personas):
  ''' For every persona, generate 50 verses of 16 lines each. Each persona's verses are written as
      JSON to "<metric_dir>/verses_<persona id>".

      Args: model - the LitModel used for generation
            personas - dict keyed by the id strings "1".."len(personas)", mapping to token lists
  '''
  os.makedirs(metric_dir, exist_ok=True)
  # ****REMEMBER: make sure you are starting from the right artist
  for artist_idx in trange(len(personas), desc="Artist #"):
    persona = str(artist_idx + 1)
    persona_ids = encode_sentence(personas[persona])
    artist_verses = [
        generate_lyrics(persona_ids, 16, model, noise_percent=.25)
        for _ in trange(50, desc="Verse #", leave=False)
    ]
    with open("{}/verses_{}".format(metric_dir, persona), 'w') as artist_verse_file:
        json.dump(artist_verses, artist_verse_file)

# Compile verses together
def compile_artist_verses(num_artists=91):
  ''' Merge the per-artist files written by generate_verses ("<metric_dir>/verses_<id>" for ids
      "1".."num_artists") into one JSON object keyed by artist id. The result is saved as
      "<metric_dir>/<model_name>_all_verses.json". '''
  generated_verses = {}
  for i in trange(num_artists):
    persona = str(i + 1)
    with open("{}/verses_{}".format(metric_dir, persona)) as artist_verse_file:
      generated_verses[persona] = json.load(artist_verse_file)
  with open("{}/{}_all_verses.json".format(metric_dir, model_name), 'w') as verses_file:
    json.dump(generated_verses, verses_file)

def load_id_personas(num_personas=91):
    ''' Build ID-only personas: maps each id string "1".."num_personas" to [id, 'W']. '''
    return {str(i): [str(i), 'W'] for i in range(1, num_personas + 1)}

In [None]:
with torch.no_grad():
  # ****REMEMBER: Need to use the correct personas per model setup
  # (for ID-only persona models use: generate_verses(model, load_id_personas()))
  generate_verses(model, load_personas_from_csv(persona_file))

Artist #:   0%|          | 0/91 [00:00<?, ?it/s]
Verse #:   0%|          | 0/50 [00:00<?, ?it/s][A
Verse #:   2%|▏         | 1/50 [00:16<13:05, 16.03s/it][A
Verse #:   4%|▍         | 2/50 [00:18<09:28, 11.85s/it][A
Verse #:   6%|▌         | 3/50 [00:20<06:56,  8.86s/it][A
Verse #:   8%|▊         | 4/50 [00:21<05:06,  6.67s/it][A
Verse #:  10%|█         | 5/50 [00:23<03:55,  5.23s/it][A
Verse #:  12%|█▏        | 6/50 [00:25<03:08,  4.28s/it][A
Verse #:  14%|█▍        | 7/50 [00:27<02:31,  3.51s/it][A
Verse #:  16%|█▌        | 8/50 [00:28<01:58,  2.82s/it][A
Verse #:  18%|█▊        | 9/50 [00:30<01:44,  2.56s/it][A
Verse #:  20%|██        | 10/50 [00:32<01:34,  2.35s/it][A
Verse #:  22%|██▏       | 11/50 [00:34<01:26,  2.22s/it][A
Verse #:  24%|██▍       | 12/50 [00:35<01:18,  2.06s/it][A
Verse #:  26%|██▌       | 13/50 [00:37<01:12,  1.95s/it][A
Verse #:  28%|██▊       | 14/50 [00:39<01:10,  1.97s/it][A
Verse #:  30%|███       | 15/50 [00:41<01:07,  1.92s/it][A
Verse #: 

In [None]:
# Merge the per-artist verse files into a single JSON file
compile_artist_verses()

100%|██████████| 91/91 [00:00<00:00, 497.16it/s]


In [None]:
# Inspect the personas parsed from the BPE-tokenized persona text file
load_personas('{}/persona_tags_bpe.txt'.format(base_dir))



{1: ['N',
  '2',
  'chainz',
  'I',
  '1',
  'R',
  'ta@@',
  'u@@',
  'heed',
  'e@@',
  'pps',
  'C',
  'atlanta',
  'M',
  'tity',
  'boi',
  ',',
  'dren@@',
  'ch@@',
  'god',
  'G',
  'playaz',
  'circle',
  'A',
  'rap',
  'or',
  'go',
  'to',
  'the',
  'league',
  ',',
  'pretty',
  'girls',
  'like',
  'trap',
  'music',
  ',',
  'collegrove',
  ',',
  'b',
  'o',
  'a',
  't',
  's',
  'ii',
  'me',
  'time',
  ',',
  'based',
  'on',
  'a',
  't',
  'r',
  'u',
  'story',
  'Y',
  '1997',
  'W'],
 2: ['N',
  '21',
  'savage',
  'I',
  '2',
  'R',
  'she@@',
  'yaa',
  'bin',
  'abraham',
  'joseph',
  'C',
  'atlanta',
  'M',
  '21',
  'A',
  'savage',
  'mode',
  'ii',
  ',',
  'i',
  'am',
  'i',
  'was',
  ',',
  'without',
  ',',
  'issa',
  'album',
  ',',
  'savage',
  'mode',
  'Y',
  '20@@',
  '13',
  'W'],
 3: ['N',
  '2pac',
  'I',
  '3',
  'R',
  'tu@@',
  'pac',
  'a@@',
  'mar@@',
  'u',
  'shakur',
  'C',
  'oakland',
  'M',
  'makaveli',
  ',',
  'mc',
  'ne

In [None]:
# Ids of tokens generation must never produce, in the form bad_words_ids expects.
# bad_words_ids bans the LAST id of each inner list, and only right after the list's preceding
# ids were generated, so each token needs its own single-id list. Tokenizing the strings would
# wrap them in <s> ... </s> and ban almost nothing.
skip_gen_tokens = ['S', '<mask>', '<pad>', '<unk>','</s>','<s>']
skip_gen_tensor = [[tok_id] for tok_id in tokenizer.convert_tokens_to_ids(skip_gen_tokens)]
skip_gen_tensor

[[0, 1437, 104, 2],
 [0, 50264, 2],
 [0, 1437, 1, 2],
 [0, 1437, 3, 2],
 [0, 1437, 2, 2],
 [0, 1437, 0, 2]]

In [None]:
# Single-id bad_words_ids for S, <mask>, <pad>, <unk>, </s>, <s>. The ids are the ones visible in
# the tokenizer output above (104, 50264, 1, 3, 2, 0); one id per list bans each token outright.
skip_gen_tensor = [[104],
 [50264],
 [1],
 [3],
 [2],
 [0]]

In [None]:
a = list(range(1, 6))
a[:1]  # a slice keeps the list type: [1]

[1]

In [None]:
# Inspect the persona table; the "persona" column holds each persona's tagged token string
personas = pd.read_csv(persona_file)
personas

Unnamed: 0,persona
0,N 2 chainz I 1 R ta@@ u@@ heed e@@ pps C atlan...
1,N 21 savage I 2 R she@@ yaa bin abraham joseph...
2,N 2pac I 3 R tu@@ pac a@@ mar@@ u shakur C oak...
3,N 50 cent I 4 R curtis james jackson iii C new...
4,N 6ix9ine I 5 R daniel her@@ nan@@ dez C new y...
...,...
86,N u god I 87 R lamont jody hawkins C new york ...
87,N vanilla ice I 88 R robert matthew van winkle...
88,N vinnie paz I 89 R vin@@ cen@@ zo lu@@ vin@@ ...
89,N wiz khalifa I 90 R cameron ji@@ bri@@ l tho@...


In [None]:
# Whitespace-split tokens of the first persona row
personas['persona'][0].split()

['N',
 '2',
 'chainz',
 'I',
 '1',
 'R',
 'ta@@',
 'u@@',
 'heed',
 'e@@',
 'pps',
 'C',
 'atlanta',
 'M',
 'tity',
 'boi',
 ',',
 'dren@@',
 'ch@@',
 'god',
 'G',
 'playaz',
 'circle',
 'Y',
 '1997']