In [0]:
!pip install pytorch_pretrained_bert

Collecting pytorch_pretrained_bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |████████████████████████████████| 133kB 3.4MB/s 
Installing collected packages: pytorch-pretrained-bert
Successfully installed pytorch-pretrained-bert-0.6.2


In [0]:
!pip install pronouncing

Collecting pronouncing
  Downloading https://files.pythonhosted.org/packages/7f/c6/9dc74a3ddca71c492e224116b6654592bfe5717b4a78582e4d9c3345d153/pronouncing-0.2.0.tar.gz
Collecting cmudict>=0.4.0 (from pronouncing)
[?25l  Downloading https://files.pythonhosted.org/packages/42/bc/606843d7cfe4d82f5a21fc46d1ae8e364ac20c57e68d1ec4190bce4f2734/cmudict-0.4.2-py2.py3-none-any.whl (938kB)
[K     |████████████████████████████████| 942kB 5.1MB/s 
[?25hBuilding wheels for collected packages: pronouncing
  Building wheel for pronouncing (setup.py) ... [?25l[?25hdone
  Stored in directory: /root/.cache/pip/wheels/81/fd/e8/fb1a226f707c7e20dbed4c43f81b819d279ffd3b0e2f06ee13
Successfully built pronouncing
Installing collected packages: cmudict, pronouncing
Successfully installed cmudict-0.4.2 pronouncing-0.2.0


In [0]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM, BertForNextSentencePrediction
import pronouncing
from itertools import chain
import string
import csv
import pdb

In [10]:
# Mount Google Drive at /content/gdrive.
# NOTE(review): the data and prediction paths used later in this notebook are
# relative ("gdrive/My Drive/..."), so they only resolve while the working
# directory is /content -- confirm before running from another directory.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
""" Try to generate from BERT """

def preprocess(tokens, tokenizer, device):
    """ Turn a token list into a single-row id tensor living on `device` """

    # the extra list level gives the tensor a leading dimension of size one
    batch_of_ids = [tokenizer.convert_tokens_to_ids(tokens)]
    return torch.tensor(batch_of_ids).to(device)

  
def get_seed_sent(toks1, toks2, tokenizer):
    """ Join both token segments and locate the positions that must be filled in.

    Returns a 3-tuple:
      * the concatenation `toks1 + toks2`,
      * a single-row tensor of segment ids (0 for every `toks1` position,
        1 for every `toks2` position),
      * the indices of every "[MASK]" token in the concatenation.

    `tokenizer` is accepted for signature compatibility but is not used.
    """

    combined = toks1 + toks2

    # positions of the placeholders the decoder has to replace
    mask_positions = [pos for pos, tok in enumerate(combined) if tok == "[MASK]"]

    # first segment is labelled 0, second segment 1
    segment_ids = [0 for _ in toks1] + [1 for _ in toks2]

    return combined, torch.tensor([segment_ids]), mask_positions

  
def load_masked_lang_model(version):
    """ Fetch the pre-trained BERT masked-language model `version`, switched to evaluation mode """
    # Module.eval() hands back the module itself, so the call can be chained
    return BertForMaskedLM.from_pretrained(version).eval()


def load_next_sent_pred_model(version):
    """ Fetch the pre-trained BERT next-sentence-prediction model `version`, switched to evaluation mode """
    # Module.eval() hands back the module itself, so the call can be chained
    return BertForNextSentencePrediction.from_pretrained(version).eval()
  
  
def predict(model, tokenizer, tok_tensor, seg_tensor, how_select="argmax", top_k=3):
    """ Run the model once and turn its per-position scores back into tokens.

    Args:
        model: callable invoked as `model(tok_tensor, seg_tensor)`; its result is
            indexed as `[batch, position, vocabulary]` scores.
        tokenizer: object providing `convert_ids_to_tokens`.
        tok_tensor: token-id tensor passed to the model.
        seg_tensor: segment-id tensor passed to the model.
        how_select: "argmax" (highest score per position), "sample" (draw from
            the full softmax distribution) or "topk" (draw among the `top_k`
            highest-scoring ids, renormalised).
        top_k: number of candidates kept by the "topk" strategy; values larger
            than the vocabulary axis are clamped to it.

    Returns:
        One predicted token per position of the FIRST batch row only.

    Raises:
        NotImplementedError: `how_select` is not one of the three strategies.
        ValueError: `top_k` is smaller than 1 when `how_select == "topk"`.
    """
    # reject bad arguments before paying for a forward pass
    if how_select not in ("sample", "topk", "argmax"):
        raise NotImplementedError("Prediction procedure %s not found!" % how_select)
    if how_select == "topk" and top_k < 1:
        raise ValueError("top_k must be at least 1, got %s" % top_k)

    preds = model(tok_tensor, seg_tensor)

    # every strategy works on the first batch row: (position, vocabulary)
    logits = preds[0]

    if how_select == "sample":
        # draw one id per position from the full distribution
        dist = Categorical(logits=F.log_softmax(logits, dim=-1))
        pred_idxs = dist.sample().tolist()

    elif how_select == "topk":
        # keep the k best ids per position, then draw among them
        k = min(top_k, logits.size(-1))
        kth_vals, kth_idx = F.log_softmax(logits, dim=-1).topk(k, dim=-1)
        dist = Categorical(logits=kth_vals)
        # the sample is a rank within the top-k; gather maps it back to a vocabulary id
        pred_idxs = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1).tolist()

    else:
        # "argmax": single best id per position
        pred_idxs = logits.argmax(dim=-1).tolist()

    return tokenizer.convert_ids_to_tokens(pred_idxs)

  
def masked_decoding(toks, device, seg_tensor, masks, model, tokenizer, selection_strategy):
    """ Fill the positions listed in `masks` one at a time, in the order given.

    `toks` is modified in place and is also the return value.
    """
    for position in masks:
        # re-encode on every step so tokens filled earlier are part of the next input
        encoded = preprocess(toks, tokenizer, device)
        filled = predict(model, tokenizer, encoded, seg_tensor, selection_strategy)
        # only the current position is taken from this round of predictions
        toks[position] = filled[position]
    return toks

  
def best_follows(text1_tokens, text2_tokens, tokenizer, model, device, k=1):
    """ Rank candidate continuations of a seed and return the best `k`.

    Args:
        text1_tokens: token list of the seed text.
        text2_tokens: list of candidate continuations, each a token list.
        tokenizer: object providing `convert_tokens_to_ids`.
        model: callable invoked as `model(token_ids, segment_ids, attention_mask)`
            with one row per candidate; column 0 of its result is used as the
            ranking score (higher is better).
            NOTE(review): presumably column 0 is the "is next" logit of the
            next-sentence-prediction head -- confirm against the model docs.
        device: device the input tensors are moved to.
        k: how many candidates to return; clamped to the number of candidates.

    Returns:
        Up to `k` entries of `text2_tokens`, best score first. An empty list
        when there are no candidates.
    """
    # nothing to rank: avoids max() on an empty sequence below
    if not text2_tokens:
        return []

    seed_ids = tokenizer.convert_tokens_to_ids(text1_tokens)
    candidate_ids = [tokenizer.convert_tokens_to_ids(candidate) for candidate in text2_tokens]

    # all rows must have equal length to form one batch, so pad to the longest candidate
    max_candidate_len = max(len(ids) for ids in candidate_ids)

    tok_ids, tok_segs, tok_attentions = [], [], []
    for ids in candidate_ids:
        padding_size = max_candidate_len - len(ids)
        tok_ids.append(seed_ids + ids + [0] * padding_size)
        # segment 0 = seed, segment 1 = candidate (padding included)
        tok_segs.append([0] * len(seed_ids) + [1] * max_candidate_len)
        # padding positions are masked out with 0
        tok_attentions.append([1] * (len(seed_ids) + len(ids)) + [0] * padding_size)

    tok_ids_tensor = torch.LongTensor(tok_ids).to(device)
    tok_segs_tensor = torch.LongTensor(tok_segs).to(device)
    tok_attention_tensor = torch.LongTensor(tok_attentions).to(device)

    seq_relationship_logits = model(tok_ids_tensor, tok_segs_tensor, tok_attention_tensor)

    # torch.topk raises if asked for more entries than exist
    k = min(k, len(text2_tokens))
    _, idxs = torch.topk(seq_relationship_logits[:, 0], k)

    return [text2_tokens[i] for i in idxs.tolist()]

 

def totalTokens(five):
    """ Tell whether the lines in `five` hold at most 500 whitespace-separated words in total.

    Lines are consumed in order and the check stops at the first line that
    pushes the running total above 500.
    """
    remaining = 500
    for line in five:
        remaining -= len(line.split())
        if remaining < 0:
            return False

    return True
                
def _read_five_line_groups(csv_path):
    """ Return every consecutive, non-overlapping run of five lyric lines in the CSV.

    Each row's 'lyrics' field is split on newlines and every line is stripped.
    A trailing run of fewer than five lines is dropped.
    """
    groups = []
    with open(csv_path, encoding='utf8') as csv_file:
        for row in csv.DictReader(csv_file):
            lines = row['lyrics'].split('\n')
            # the upper bound guarantees that start + 4 is still a valid index
            for start in range(0, len(lines) - 4, 5):
                groups.append([line.strip() for line in lines[start:start + 5]])
    return groups


def _build_seed_and_targets(five_list):
    """ Build the (seed, masked target, gold target) texts for one five-line group.

    The first three lines form the seed. In the target, the last word of the
    fourth line is replaced by "[MASK]"; the gold text keeps the original
    words. Returns None when the second or fourth line is empty, in which case
    the group is meant to be skipped.
    """
    # mark the beginning with [CLS] and every line break with [SEP]
    text = "[CLS] " + ("\n").join(five_list).replace("\n", " [SEP] ")
    text_split = text.split(" [SEP] ")

    # every other line must end in punctuation; an empty one disqualifies the group
    for i in range(1, len(text_split), 2):
        if not text_split[i]:
            return None
        if text_split[i][-1] not in string.punctuation:
            text_split[i] = text_split[i] + "."

    # seed length
    text1_len = 3
    text1 = (" [SEP] ").join(text_split[:text1_len]) + " [SEP]"

    masked_parts, gold_parts = [], []
    for i, line in enumerate(text_split[text1_len:]):
        if i % 2 == 0:
            # drop the last whitespace-separated word and ask the model for it
            # NOTE(review): the masked line always ends in ".", whereas the gold
            # line keeps its own final punctuation -- confirm this is intended.
            masked_parts.append(' '.join(line.split()[:-1]) + " [MASK] " + ". [SEP] ")
            if line[-1] not in string.punctuation:
                gold_parts.append(line + ". [SEP] ")
            else:
                gold_parts.append(line + " [SEP] ")
        else:
            masked_parts.append(line + " [SEP] ")
            gold_parts.append(line + " [SEP] ")

    return text1, (" ").join(masked_parts), (" ").join(gold_parts)


@torch.no_grad()
def main(data_path="gdrive/My Drive/courseWorks/Lyrics/data/test_rock.csv",
         pred_path="gdrive/My Drive/courseWorks/Lyrics/preds/test_pred_file.txt",
         gold_path="gdrive/My Drive/courseWorks/Lyrics/preds/test_gold_file.txt"):
    """ Predict the masked last word of the fourth line of every five-line lyric group.

    For each group read from `data_path`, the first three lines seed the masked
    language model, the last word of the fourth line is masked and decoded with
    the "argmax" strategy, and the result is passed through the
    next-sentence-prediction ranker. The predicted token sequence is appended
    to `pred_path` and the matching gold sequence to `gold_path`, one record
    per sample, each surrounded by newlines.

    Groups with an empty second or fourth line, or with more than 512 tokens,
    are skipped. Progress is printed every 1000 processed groups.

    Args:
        data_path: CSV file with a 'lyrics' column.
        pred_path: text file the predictions are appended to.
        gold_path: text file the gold sequences are appended to.
    """
    # set device : use CUDA backend if GPU available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # get BERT tokenizer and pre-trained models, moved to the device
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    masked_lang_model = load_masked_lang_model('bert-base-uncased').to(device)
    next_sent_pred_model = load_next_sent_pred_model('bert-base-uncased').to(device)

    five_pairs = _read_five_line_groups(data_path)
    cnt = 0

    # open both outputs once instead of once per sample
    with open(pred_path, "a") as pred_file, open(gold_path, "a") as gold_file:
        for five_list in five_pairs:
            built = _build_seed_and_targets(five_list)
            if built is None:
                continue
            text1, text2, text2_copy = built

            toks1 = tokenizer.tokenize(text1)
            toks2 = tokenizer.tokenize(text2)
            toks2_copy = tokenizer.tokenize(text2_copy)

            # get lyric tokens, segment tensors and mask indices
            toks, seg_tensor, mask_ids = get_seed_sent(toks1, toks2, tokenizer)
            # presumably the model's maximum input length -- longer samples are skipped
            if len(toks) > 512:
                continue
            seg_tensor = seg_tensor.to(device)

            # get [MASK] predictions
            pred_toks = masked_decoding(toks, device, seg_tensor, mask_ids,
                                        masked_lang_model, tokenizer, "argmax")

            # only the target part (everything after the seed) is a candidate
            outputs = [pred_toks[len(toks1):]]

            # pick the best candidate according to the next sentence prediction model
            outz = best_follows(toks1, outputs, tokenizer, next_sent_pred_model, device, k=1)

            toks_pred = toks1 + outz[0]
            toks_gold = toks1 + toks2_copy
            pred_file.write("\n" + (" ").join(toks_pred) + "\n")
            gold_file.write("\n" + (" ").join(toks_gold) + "\n")

            # push each record out immediately so an interrupted run keeps its results
            pred_file.flush()
            gold_file.flush()

            # clear cache
            torch.cuda.empty_cache()

            cnt += 1
            if cnt % 1000 == 0:
                print("Processed %s 5-pairs" % cnt)

# Run the whole generation pipeline when this cell is executed.
if __name__ == '__main__':
    main()

100%|██████████| 231508/231508 [00:00<00:00, 2516841.14B/s]
100%|██████████| 407873900/407873900 [00:07<00:00, 58231277.84B/s]


Processed 1000 5-pairs
Processed 2000 5-pairs
Processed 3000 5-pairs
Processed 4000 5-pairs
Processed 5000 5-pairs
Processed 6000 5-pairs
Processed 7000 5-pairs
Processed 8000 5-pairs
Processed 9000 5-pairs
Processed 10000 5-pairs
Processed 11000 5-pairs
Processed 12000 5-pairs
Processed 13000 5-pairs
Processed 14000 5-pairs
Processed 15000 5-pairs
Processed 16000 5-pairs
Processed 17000 5-pairs
Processed 18000 5-pairs
Processed 19000 5-pairs
Processed 20000 5-pairs
Processed 21000 5-pairs
Processed 22000 5-pairs
Processed 23000 5-pairs
Processed 24000 5-pairs
Processed 25000 5-pairs
Processed 26000 5-pairs
Processed 27000 5-pairs
Processed 28000 5-pairs
Processed 29000 5-pairs
Processed 30000 5-pairs
Processed 31000 5-pairs
Processed 32000 5-pairs
Processed 33000 5-pairs
Processed 34000 5-pairs
Processed 35000 5-pairs
Processed 36000 5-pairs
Processed 37000 5-pairs
Processed 38000 5-pairs
Processed 39000 5-pairs
Processed 40000 5-pairs
Processed 41000 5-pairs
Processed 42000 5-pairs
P

In [15]:
from nltk.translate.bleu_score import sentence_bleu

def loadData(name):
    """Collect the second-to-last " [SEP] "-delimited field of every line in file `name`.

    Lines consisting of a single newline are skipped. A line with fewer than
    two fields raises IndexError.
    """
    with open(name) as handle:
        return [
            record.split(" [SEP] ")[-2]
            for record in handle
            if record != '\n'
        ]

def bleuScore(gold, pred, weights=(.334, 0.333, 0.333, 0), smoothing_function=None):
    """Average sentence-level BLEU of `pred` against `gold`, paired by position.

    Args:
        gold: reference sentences, one string each.
        pred: candidate sentences, one string each, aligned with `gold`.
        weights: n-gram weights forwarded to `sentence_bleu`.
        smoothing_function: optional smoothing callable forwarded to
            `sentence_bleu`; None keeps that function's default behaviour.

    Returns:
        The mean of the per-sentence scores, or 0.0 when there are no sentences.

    Raises:
        ValueError: `gold` and `pred` differ in length, so they cannot be paired.
    """
    totalSentences = len(gold)
    if totalSentences != len(pred):
        raise ValueError("gold and pred must have the same length (got %d and %d)"
                         % (totalSentences, len(pred)))
    # avoid dividing by zero when there is nothing to score
    if totalSentences == 0:
        return 0.0

    cumulativeBleu = 0
    for goldSentence, predSentence in zip(gold, pred):
        # sentence_bleu expects a list of tokenised references and one tokenised candidate
        reference = [goldSentence.split(' ')]
        candidate = predSentence.split(' ')
        cumulativeBleu += sentence_bleu(reference, candidate, weights=weights,
                                        smoothing_function=smoothing_function)

    return cumulativeBleu / totalSentences

def accuracy(gold, pred):
    """Fraction of positions at which `pred` equals `gold`.

    Args:
        gold: reference items.
        pred: predicted items, aligned with `gold`.

    Returns:
        Number of exact matches divided by the number of items, or 0.0 when
        there are no items.

    Raises:
        ValueError: `gold` and `pred` differ in length, so they cannot be paired.
    """
    num_total = len(gold)
    if num_total != len(pred):
        raise ValueError("gold and pred must have the same length (got %d and %d)"
                         % (num_total, len(pred)))
    # avoid dividing by zero when there is nothing to compare
    if num_total == 0:
        return 0.0

    num_correct = sum(1 for expected, actual in zip(gold, pred) if expected == actual)
    return num_correct / num_total
  

# Score the saved predictions against the saved gold lines.
# Both metrics pair gold[i] with pred[i], so the two files are assumed to hold
# the same number of non-blank lines.
if __name__ == '__main__':
  gold = loadData("gdrive/My Drive/courseWorks/Lyrics/preds/test_gold_file.txt")
  pred = loadData("gdrive/My Drive/courseWorks/Lyrics/preds/test_pred_file.txt")
  print(f'Accuracy: {accuracy(gold, pred):.2f}')
  print(f'BLEU score: {bleuScore(gold, pred):.2f}')

Accuracy: 0.31


Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


BLEU score: 0.80
