In [29]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")

from keras.preprocessing.sequence import pad_sequences

from transformers import XLNetTokenizer, XLNetForSequenceClassification, XLNetModel, AdamW

import pandas as pd
import json
import re
import string
from tqdm import tqdm

In [4]:
MAX_LEN = 1024
batch_size = 8

In [5]:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/760 [00:00<?, ?B/s]

In [6]:
model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased", num_labels=4)

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.bias', 'logits_proj.bias', 'logits_proj.weight', 'sequence_summary.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

In [7]:
model.load_state_dict(torch.load('e_44_100.ckpt', map_location=device))
model = nn.DataParallel(model)
model.to(device)

DataParallel(
  (module): XLNetForSequenceClassification(
    (transformer): XLNetModel(
      (word_embedding): Embedding(32000, 768)
      (layer): ModuleList(
        (0): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (layer_1): Linear(in_features=768, out_features=3072, bias=True)
            (layer_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (activation_function): GELUActivation()
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dr

In [8]:
def Dataloader(songs):
    tokenized_texts = [tokenizer.tokenize(song) for song in songs]
    input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
    input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
    attention_masks = []


    for seq in input_ids:
        seq_mask = [float(i>0) for i in seq]
        attention_masks.append(seq_mask)
    validation_inputs = torch.tensor(input_ids)
    validation_masks = torch.tensor(attention_masks)

    validation_data = TensorDataset(validation_inputs, validation_masks)
    validation_sampler = SequentialSampler(validation_data)
    validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
    
    return validation_dataloader


In [9]:
def eval(validation_dataloader):
    model.eval()
    val_len = 0
    total_loss = 0
    predictions = []
    with torch.no_grad():
        for step, batch in enumerate(validation_dataloader):
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_mask = batch
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            preds = outputs[0].detach().cpu().numpy()
            if step%100 == 0 and step:
                print("Step: %s" (step/len(validation_dataloader)))

            predictions.append(preds)
    return predictions

  print("Step: %s" (step/len(validation_dataloader)))


In [10]:
def clean(lyr):
    lyricstring = ""
    for line in lyr:
        if line != '':
            lyricstring = lyricstring + line + " "
    cleanr = re.compile('\[.*?\]|\(.*?\)')
    lyricstring = re.sub(cleanr, '', lyricstring.lower()) 
    punct = ""
    for p in string.punctuation:
        if p != '`' and p != "'":
            punct = punct + p
    lyricstring = "".join([char for char in lyricstring if char not in punct])
    
    return lyricstring

In [16]:
with open("track-lyrics.json") as f:
    track_lyrics = json.load(f)
track_lyrics = list(map(lambda x: x.split("\n")[1:], track_lyrics))

In [30]:
list_lyrics = []
for i in range(len(track_lyrics)):
    list_lyrics.append(clean(track_lyrics[i]))

100%|█████████████████████████████████████████████████████████████████| 3652/3652 [00:02<00:00, 1776.74it/s]


In [31]:
vd = Dataloader(list_lyrics)

In [None]:
arr = eval(vd)

In [24]:
print(arr)

[array([[ 0.55828905, -1.794007  ,  1.5331506 ,  0.834533  ],
       [-1.413558  , -0.02864932,  2.1458972 , -0.5030216 ],
       [-0.8137109 , -0.39609116,  2.4109771 , -0.72511405],
       [ 0.26034573, -0.8777653 ,  1.0417238 ,  0.35517716],
       [-1.344079  , -0.86249316,  4.7164974 , -1.7720084 ],
       [ 1.3024284 , -1.6338828 ,  0.3404919 ,  1.4989481 ],
       [ 0.20670183, -0.1348671 ,  0.45600075,  0.0731061 ],
       [ 0.83906615, -1.3575709 , -0.116785  ,  1.684588  ]],
      dtype=float32), array([[ 1.5307236 , -1.9505694 , -0.68550193,  2.5091457 ],
       [ 0.894444  , -1.7056776 , -0.19522905,  2.3239458 ]],
      dtype=float32)]


In [26]:
len(arr[0])

8