In [1]:
!pip install pytorch_pretrained_bert
!pip install pronouncing



In [0]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-large-uncased')
model.eval()

sentence = '[CLS] Charles is a tailor [SEP] He is tall [SEP]'
# sentence = '[CLS] Charles is a tailor [SEP] Excavation is important [SEP]'
toks = tokenizer.tokenize(sentence)
tok_ids = tokenizer.convert_tokens_to_ids(toks)
tok_tensor = torch.LongTensor([tok_ids])
token_type_ids_tensor = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
seq_relationship_logits = model(tok_tensor, token_type_ids_tensor)
print(seq_relationship_logits)


In [0]:
import itertools
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForNextSentencePrediction
import pronouncing
import re

def predict_next_sentence(sentenceA, sentenceBs, tokenizer, model, rhyme=False):
    seq_relationship_logits = get_next_sentence_logits(sentenceA, sentenceBs, tokenizer, model)
    if rhyme:
        for i, sentenceB in enumerate(sentenceBs):
            last_word_A = re.sub(r'[^\w]', '', sentenceA.split()[-1])
            last_word_B = re.sub(r'[^\w]', '', sentenceB.split()[-1])
            if last_word_A in pronouncing.rhymes(last_word_B) or \
                last_word_B in pronouncing.rhymes(last_word_A):
                print(f'{last_word_A} rhymes with {last_word_B}')
                seq_relationship_logits[i, 0] += 10
    return sentenceBs[seq_relationship_logits[:,0].argmax().tolist()]

def get_ids_types_attention_from_sentence_pair(sentenceA, sentenceB, pad_total_size, tokenizer):
    sentenceA_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentenceA))
    sentenceB_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentenceB))
    padding_size = pad_total_size - len(sentenceA_ids) - len(sentenceB_ids)
    sentence_ids = sentenceA_ids + sentenceB_ids + [0] * padding_size
    sentence_types = [0] * len(sentenceA_ids) + [1] * len(sentenceB_ids) + [0] * padding_size
    sentence_attention = [1] * (len(sentenceA_ids) + len(sentenceB_ids)) + [0] * padding_size
    return sentence_ids, sentence_types, sentence_attention

def reconstruct_song(lines, tokenizer, model):
    pad_total_size = 2 * max(len(line) for line in lines)
    for ordering in itertools.permutations(lines):
        pair_ids, pair_types, pair_attentions = [], [], []        
        for i in range(len(ordering) - 1):
            pair_id, pair_type, pair_attention = \
                get_ids_types_attention_from_sentence_pair(ordering[i], ordering[i + 1], pad_total_size, tokenizer)
            pair_ids.append(pair_id)
            pair_types.append(pair_type)
            pair_attentions.append(pair_attention)
        ids_tensor = torch.LongTensor(pair_ids)
        types_tensor = torch.LongTensor(pair_types)
        attention_tensor = torch.LongTensor(pair_attentions)
        seq_relationship_logits = model(ids_tensor, types_tensor, attention_tensor)
#         print(seq_relationship_logits)
#         print(sum(seq_relationship_logits[:, 0].tolist()))
#         print(ordering)
    
    
def get_next_sentence_logits(sentenceA, sentenceBs, tokenizer, model):
    sentenceA_toks = tokenizer.tokenize(sentenceA)
    sentenceA_ids = tokenizer.convert_tokens_to_ids(sentenceA_toks)
    sentenceA_types = [0] * len(sentenceA_ids)
    sentenceA_attention = [1] * len(sentenceA_ids)
    tok_ids = []
    tok_types = []
    tok_attention = []
    
    sentenceBs_ids = []
    for sentenceB in sentenceBs:
        sentenceB_toks = tokenizer.tokenize(sentenceB)
        sentenceB_ids = tokenizer.convert_tokens_to_ids(sentenceB_toks)
        sentenceBs_ids.append(sentenceB_ids)
        
    max_sentenceB_length = max(len(sentenceB_ids) for sentenceB_ids in sentenceBs_ids)
    for sentenceB_ids in sentenceBs_ids:
        padding_size = max_sentenceB_length - len(sentenceB_ids)
        padded_sentenceB_ids = sentenceB_ids + [0] * padding_size
        padded_sentenceB_types = [1] * max_sentenceB_length
        padded_sentenceB_attention = [1] * len(sentenceB_ids) + [0] * padding_size
        tok_ids.append(sentenceA_ids + padded_sentenceB_ids)
        tok_types.append(sentenceA_types + padded_sentenceB_types)
        tok_attention.append(sentenceA_attention + padded_sentenceB_attention)
    tok_ids_tensor = torch.LongTensor(tok_ids)
    tok_types_tensor = torch.LongTensor(tok_types)
    tok_attention_tensor = torch.LongTensor(tok_attention)
    seq_relationship_logits = model(tok_ids_tensor, tok_types_tensor, tok_attention_tensor)
    return seq_relationship_logits
    

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-large-uncased')
model.eval()

BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (

In [0]:
# Demonstration of predict_next_sentence
predicted_sentence = predict_next_sentence('[CLS] Charles is a tailor [SEP]', ['He is green [SEP]', 'He is very tall [SEP]', 'Excavation is important [SEP]'], tokenizer, model)
print(predicted_sentence)

He is very tall [SEP]


In [0]:
import csv
import random
import argparse

num_to_generate = 10
num_options = 30

def generate_predictions(args):
    all_lines = []
    all_pairs = []
    with open(args.datafile, encoding='utf8') as csv_file:
        csv_reader = csv.DictReader(csv_file)
        for row in csv_reader:
            lines = row['lyrics'].split('\n')
            for i in range(len(lines) - 1):
                all_pairs.append((lines[i], lines[i + 1]))
                all_lines.append(lines[i])
            all_lines.append(lines[len(lines) - 1])

    sampled_data_x = {}
    sampled_data_y = {}
    correct_pairs = random.sample(all_pairs, num_to_generate)
    for line1, line2 in correct_pairs:
        sampled_data_y[line1] = line2
        sampled_data_x[line1] = [line2]
        sampled_data_x[line1].extend(random.sample(all_lines, num_options - 1))
    
    
#     Batch make predictions to speed up runtime - crashes Colab for using too much RAM
#
#     all_sentence_ids, all_sentence_types, all_sentence_attentions = [], [], []
#     for i, (line1, line2s) in enumerate(sampled_data_x.items()):
#         sentenceA = line1
#         for sentenceB in line2s:
#             sentence_ids, sentence_types, sentence_attentions = \
#                 get_ids_types_attention_from_sentence_pair(sentenceA, sentenceB, 200, tokenizer)
#             all_sentence_ids.append(sentence_ids)
#             all_sentence_types.append(sentence_types)
#             all_sentence_attentions.append(sentence_attentions)
#     ids_tensor = torch.LongTensor(all_sentence_ids)
#     types_tensor = torch.LongTensor(all_sentence_types)
#     attention_tensor = torch.LongTensor(all_sentence_attentions)
#     seq_relationship_logits = model(ids_tensor, types_tensor, attention_tensor)    
#     predictions = []
#     for i, (line1, line2s) in enumerate(sampled_data_x.items()):
#         prediction_inx = seq_relationship_logits[i*num_options : (i + 1)*num_options, 0].argmax().tolist()
#         predictions.append((line1, line2s[prediction_inx]))
    

    with open('predfile_norhyme', 'w') as file_norhyme:
        with open('predfile_rhyme', 'w') as file_rhyme:
            for i, (line1, line2s) in enumerate(sampled_data_x.items()):
                line2 = predict_next_sentence(line1, line2s, tokenizer, model)
                file_norhyme.write(f'{line1}\t{line2}\n')
                line2 = predict_next_sentence(line1, line2s, tokenizer, model, rhyme=True)
                file_rhyme.write(f'{line1}\t{line2}\n')
                if (i + 1) % 10 == 0:
                    print(f'Finished predicting {i + 1} lines...')
    with open('goldfile', 'w') as file:
        for line1, line2 in sampled_data_y.items():
            file.write(f'{line1}\t{line2}\n')

In [0]:
import csv
import random
import argparse

def generate_predictions_one_song(args):
    
    num_songs = 10
    
    with open('goldfile', 'w') as file_gold:
            with open('predfile_random_onesong', 'w') as file_random:
                with open('predfile_norhyme_onesong', 'w') as file_norhyme:
                    with open('predfile_rhyme_onesong', 'w') as file_rhyme:
                        print('Deleting old files...')
                        
    for i in range(num_songs):
        chosen_row = None
        n = 1
        with open(args.datafile, encoding='utf8') as csv_file:
            csv_reader = csv.DictReader(csv_file)
            for row in csv_reader:
                if len(row['lyrics'].split('\n')) > 20:
                    continue
                if random.random() < 1 / n:
                    chosen_row = row
                n += 1

        lines = chosen_row['lyrics'].split('\n')   

        print(f"Chosen Song: {chosen_row['song']}")
        print()
        print('Lyrics:')
        print('\n'.join(lines))

        sampled_data_x = {}
        sampled_data_y = {}
        for i in range(len(lines) - 1):
            sampled_data_y[lines[i]] = lines[i + 1]
            sampled_data_x[lines[i]] = list(set(line for line in lines if line != lines[i]))

        with open('goldfile', 'a') as file_gold:
            with open('predfile_random_onesong', 'a') as file_random:
                with open('predfile_norhyme_onesong', 'a') as file_norhyme:
                    with open('predfile_rhyme_onesong', 'a') as file_rhyme:
                        for i, (line1, line2s) in enumerate(sampled_data_x.items()):
                            line2 = sampled_data_y[line1]
                            file_gold.write(f'{line1}\t{line2}\n')
                            line2 = random.choice(line2s)
                            file_random.write(f'{line1}\t{line2}\n')
                            line2 = predict_next_sentence(line1, line2s, tokenizer, model)
                            file_norhyme.write(f'{line1}\t{line2}\n')
                            line2 = predict_next_sentence(line1, line2s, tokenizer, model, rhyme=True)
                            file_rhyme.write(f'{line1}\t{line2}\n')
#                             if (i + 1) % 10 == 0:
#                                 print(f'Finished predicting {i + 1} lines...')

In [0]:
import csv
import random            
    
def generate_song(args):
    
    num_lines = args.num_lines
    num_choices = args.num_choices
    
    all_lines = []
    with open(args.datafile, encoding='utf8') as csv_file:
        csv_reader = csv.DictReader(csv_file)
        for row in csv_reader:
            lines = row['lyrics'].split('\n')
            for i in range(len(lines)):
                all_lines.append(lines[i])
    
    song_lines = [random.choice(all_lines)]
    
    for i in range(num_lines - 1):
        lines = random.sample(all_lines, num_choices)
        next_line = predict_next_sentence(song_lines[-1], lines, tokenizer, model, rhyme=True)
        song_lines.append(next_line)

    with open('generate_song3.txt', 'w') as file_song:
        for line in song_lines:
            file_song.write(f'{line}\n')
       
    

In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--datafile', type=str, required=True)
parser.add_argument('--num-lines', type=int, required=True)
parser.add_argument('--num-choices', type=int, required=True)
args = parser.parse_args(['--datafile', 'english_rock.csv', '--num-lines', '10', '--num-choices', '1'])
generate_song(args)

In [0]:
import argparse
from nltk.translate.bleu_score import sentence_bleu
import warnings

warnings.simplefilter("ignore", UserWarning)

def bleuScore(gold, pred):
    cumulativeBlue, totalSentences = 0, len(gold)

    for line in gold:
        assert line in pred
        reference = [gold[line].split(' ')]
        candidate = pred[line].split(' ') 
        cumulativeBlue += sentence_bleu(reference, candidate, weights=(.334, 0.333, 0.333, 0))

    return cumulativeBlue / totalSentences

def accuracy(gold, pred):
    num_correct, num_total = 0, 0
    for line1 in gold:
        assert line1 in pred
        if gold[line1] == pred[line1]:
            num_correct += 1
        num_total += 1

    accuracy = num_correct / num_total

    return accuracy

def loadData(name):
    data = {}
    with open(name) as file:
        for line in file:
            line1, line2 = line.strip().split('\t')
            data[line1] = line2

    return data

def score_predictions(args):
    gold = loadData(args.goldfile)
    pred = loadData(args.predfile)

    assert len(gold) == len(pred)

    print(f'Accuracy: {accuracy(gold, pred):.2f}')
    print(f'BLEU score: {bleuScore(gold, pred):.2f}')

In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--datafile', type=str, required=True)
args = parser.parse_args(['--datafile', 'train_rock.csv'])
generate_predictions(args)

In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--datafile', type=str, required=True)
args = parser.parse_args(['--datafile', 'train_rock.csv'])
generate_predictions_one_song(args)

Deleting old files...
Chosen Song: dead-heat

Lyrics:
put them up against the wall
pull the trigger watch them fall
they can only feel the pain
stand them up and start again
tell the sargeant what you saw
fear the long arm of the law
even though it's hanging there
drugs and bad guys you'd better beware
of dead heat they're dead heat
if you shoot 'em down they'll be back on thier feet
they're dead heat they're dead heat
if you shoot 'em down they'll be back on the street
take 'em down contempt divine (?)
a cat that looks like frankenstein
he's holding up a jewellery store
listen to his bullets roar
their job is done they're all alone
they work their fingers to the bone
they're weary as they walk their beat
all day long they're dead on their feet
they're dead heat they're dead heat
certified zombies from their head to their feet
they're dead heat they're dead heat
if you shoot 'em down they'll be back on thier feet
wall rhymes with fall
fall rhymes with wall
pain rhymes with again
again 

In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--goldfile', type=str, required=True)
parser.add_argument('--predfile', type=str, required=True)
args = parser.parse_args(['--goldfile', 'goldfile', '--predfile', 'predfile_random_onesong'])
score_predictions(args)

parser = argparse.ArgumentParser()
parser.add_argument('--goldfile', type=str, required=True)
parser.add_argument('--predfile', type=str, required=True)
args = parser.parse_args(['--goldfile', 'goldfile', '--predfile', 'predfile_norhyme_onesong'])
score_predictions(args)

parser = argparse.ArgumentParser()
parser.add_argument('--goldfile', type=str, required=True)
parser.add_argument('--predfile', type=str, required=True)
args = parser.parse_args(['--goldfile', 'goldfile', '--predfile', 'predfile_rhyme_onesong'])
score_predictions(args)

Accuracy: 0.10
BLEU score: 0.22
Accuracy: 0.05
BLEU score: 0.12
Accuracy: 0.38
BLEU score: 0.56


In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--goldfile', type=str, required=True)
parser.add_argument('--predfile', type=str, required=True)
args = parser.parse_args(['--goldfile', 'goldfile', '--predfile', 'predfile_norhyme'])
score_predictions(args)

parser = argparse.ArgumentParser()
parser.add_argument('--goldfile', type=str, required=True)
parser.add_argument('--predfile', type=str, required=True)
args = parser.parse_args(['--goldfile', 'goldfile', '--predfile', 'predfile_rhyme'])
score_predictions(args)

Accuracy: 0.00
BLEU score: 0.15
Accuracy: 0.00
BLEU score: 0.15


In [0]:
import csv

loops = 0

with open('english_rock.csv', encoding='utf8') as csv_file:
    csv_reader = csv.DictReader(csv_file)
    for row in csv_reader:
        lines = row['lyrics'].split('\n')
        num_rhymes = 0
        for i in range(len(lines) - 1):
            if not lines[i].split() or not lines[i+1].split():
                continue
            if lines[i].split()[-1] in pronouncing.rhymes(lines[i+1].split()[-1]):
                num_rhymes += 1
        if num_rhymes / len(lines) > .3:
            print(f'Song: {row["song"]} Prob of rhyme: {num_rhymes / len(lines)}')
        loops += 1
        if loops > 100:
            continue
            break

Song: you-re-cracked Prob of rhyme: 0.5
Song: everybody-makes-me-barf Prob of rhyme: 0.6
Song: goodbye Prob of rhyme: 0.3333333333333333
Song: perfect-ten Prob of rhyme: 0.44
Song: don-t-wait Prob of rhyme: 0.3225806451612903
Song: paper-thin-hotel Prob of rhyme: 0.38461538461538464
Song: candy-came-back Prob of rhyme: 0.3333333333333333
Song: nostradamus Prob of rhyme: 0.375
Song: shah-of-shahs Prob of rhyme: 0.32142857142857145
Song: cleave-to-me Prob of rhyme: 0.3684210526315789
Song: you-give-love-a-bad-name Prob of rhyme: 0.34782608695652173
Song: heart-of-a-bad-machine Prob of rhyme: 0.36363636363636365
Song: automatic-thrill Prob of rhyme: 0.30952380952380953
Song: dog-day-dog-night Prob of rhyme: 0.45714285714285713
Song: lord-of-the-dusk Prob of rhyme: 0.36363636363636365
Song: little-man Prob of rhyme: 0.3125
Song: conniption-fit Prob of rhyme: 0.3076923076923077
Song: you-shout-you-shout-you-shout-you-shout Prob of rhyme: 0.3170731707317073
Song: mine-tonite Prob of rhyme: 0

KeyboardInterrupt: ignored

In [0]:
line_choices = [[lineB for lineB in lines if lineB != lineA] for lineA in lines]
short_lines = lines[:4]

In [0]:
next_lines = []
next_lines_gold = lines[1:]
for i in range(len(lines) - 1):
    next_line = predict_next_sentence(lines[i], line_choices[i], tokenizer, model)
    next_lines.append(next_line)

num_right = 0
num_total = 0
for pred, gold in zip(next_lines, next_lines_gold):
    num_total += 1
    if pred == gold:
        num_right += 1

print(f'Accuracy: {num_right / num_total}')
print(lines)
print(next_lines)
print(next_lines_gold)

Accuracy: 0.07894736842105263
["A lot of cats are hatin', slandering makin' bad statements", 'Mad cause they sit on their ass just stagnating', "Always vacillatin', now classmates I graduated with", "Are wonderin' how the stupidest kid up in the class made it", "Sick landscapin' and jammin' down in my mans basement", 'Getting restraints and complaints from mad neighbors', 'Now prejudice bigots say I sound just like them damn #%#', "Them pair of lenses ain't repairin' their impaired vision", "I'm on a mission escaping my own prison", "Inflicting more pain then you're givin' see I'm my own victim", "I can't believe I let you take up my time", "Take up space in mind, give it here, I'm takin' what's mine", '(Chrous)', 'The only ounce of power that I have', 'Is what I do with now', 'And how I let the hours pass', "I dont' know how long I'm gonna last", "So I can't let ya snatch the powder out my hourglass", "Everyday that I'm awake I face the angel of death", 'He may be taken my breath, so 

In [0]:
a = torch.LongTensor([[1,2],[3,2]])

In [0]:
a[:, 0] += 1

In [0]:
a

tensor([[2, 2],
        [4, 2]])