# Initialize framework

Enter the following code in the browser console to prevent Colab from suspending the current session (this may no longer work since March 2021 because of the captcha)

```javascript
// Log the attempt, then click the "connect" element that lives inside the
// shadow DOM of Colab's connect button in the top toolbar.
function ConnectButton() {
    console.log("Connect pushed");
    const connectHost = document.querySelector("#top-toolbar > colab-connect-button");
    const connectElement = connectHost.shadowRoot.querySelector("#connect");
    connectElement.click();
}
let funId = setInterval(ConnectButton,60000);
```
\\
Enter the following line to cancel continuous clicking

```javascript
clearInterval(funId);
```


In [None]:
# Check python version because from 3.10 we can use match ... case
!python3 -V

Python 3.7.11


Import necessary libraries and try to use the GPU device (if available)

In [None]:
from typing import Callable, Optional, List, Union, Dict, Tuple
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import re
import random

import pandas as pd

# Remove pandas' display limits (columns, rows and cell width) so printed
# DataFrames are never truncated
for option_name in ("display.max_columns", "display.max_rows", "display.max_colwidth"):
  pd.set_option(option_name, None)

# Select the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
print("Using %s device" %(device))

Using cuda device


Mount Google Drive to load the dataset of lyrics and the word2vec pretrained embedding

In [None]:
# Mount Google Drive at /content/drive so the files referenced below are reachable
# (google.colab is only importable inside a Colab runtime)
from google.colab import drive
drive.mount("/content/drive")

# Alternative dataset files kept for reference (disabled)
#dataset_path  = "/content/drive/MyDrive/DM project - NLP lyrics generation/english_cleaned_lyrics.csv"
#dataset_path  = "/content/drive/MyDrive/DM project - NLP lyrics generation/preprocessed_lyrics.csv"
#dataset_path  = "/content/drive/MyDrive/DM project - NLP lyrics generation/preprocessed_lyrics_0-4600 (LYRICS).csv" # 4600 lyrics and not artists (OLD PREPROCESSING)
# CSV of preprocessed lyrics loaded by the dataset class
dataset_path  = "/content/drive/MyDrive/DM project - NLP lyrics generation/preprocessed_lyrics_0-35000 (LYRICS).csv" # 35000 lyrics and not artists
# Text file of pre-trained word vectors, scanned line by line by create_vocab()
word2vec_path = "/content/drive/MyDrive/DM project - NLP lyrics generation/cc.en.300.vec"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **Word2vec and vocabulary**

Define
* a **dictionary of word indices** indexed by corresponding words,
* an **inverse dictionary of words** indexed by corresponding indices,
* a **list of word embedding vectors**,
* a **set of words in the lyrics**,
* three functions to easily retrieve
  * the indices from the words,
  * the words from the indices,
  * the vector representation of the words

In [None]:
# Module-level vocabulary state, read by the helper functions below and
# (re)populated by create_vocab().
# words2indices maps a word to its integer index, indices2words is the inverse mapping
words2indices = {}
indices2words = {}
# Word embedding vectors, stored at the position given by the word's index
# (a list while the vocabulary is being built; create_vocab() stacks it into a tensor)
word_vectors = []
# Set of the distinct words found in the lyrics
lyrics_vocab = set()

# Get word from index
def get_word_from_index(idx):
  """Return the word stored under vocabulary index `idx`, or None if the index is unknown."""
  # dict.get yields None for a missing key instead of raising KeyError
  word = indices2words.get(idx)
  return word

# Get index from word
def get_index_from_word(word):
  """Return the vocabulary index of `word`, or None if the word is not in the vocabulary."""
  # dict.get yields None for a missing key instead of raising KeyError
  index = words2indices.get(word)
  return index

# Get word vector from word
def get_word_vector(word):
  """Return the embedding vector of `word`, or None if the word is not in the vocabulary."""
  idx = get_index_from_word(word)

  # Identity test against None: index 0 is a valid (falsy) index, so a plain
  # truthiness check would be wrong here
  if idx is None:
    return None

  return word_vectors[idx]

Define the function to create the vocabulary

In [None]:
# Dictionary (word -> embedding vector) of words that are in the word2vec file but not in the lyrics.
# create_vocab() uses it to build vectors for lyrics words missing from the word2vec file
# by combining existing vectors (such as word + 's) instead of using a random vector
skipped_words = {}

# Additional words not in the word2vec file to be added
# NOTE(review): this is a set, so the order in which create_vocab() assigns indices to these
# words follows set iteration order — confirm the indices do not need to be stable across runs
additional_words = {"PAD", "UNK", "\n", "ain't"}

# Add word and corresponding vector to the vocabulary
def add_word(word, vector):
  """Register `word` with embedding `vector` under the next free vocabulary index."""
  # The next free index equals the number of vectors stored so far
  new_index = len(word_vectors)

  # Store the vector at that position
  word_vectors.append(vector)

  # Keep both lookup directions in sync
  indices2words[new_index] = word
  words2indices[word]      = new_index

# Create vocabulary (i.e. populate the data structures previously defined)
def create_vocab():
  """Populate words2indices, indices2words and word_vectors from the word vector file.

  Steps:
  1. reset the vocabulary and add the additional words with random vectors
     ("ain't" is added later as a perturbation of the vector of "isn't");
  2. scan the file at `word2vec_path` and add every word found in `lyrics_vocab`,
     remembering some of the other words in `skipped_words`;
  3. add lyrics words of the form word + 's / 'd that are still missing, by summing
     the vector of the base word and the vector of the ending;
  4. stack the vectors into a single tensor (moved to the GPU when available).
  """
  global words2indices, indices2words, word_vectors, lyrics_vocab

  # Calculate vocabulary length and define word embedding vector size
  VOCABULARY_LEN = len(lyrics_vocab) + len(additional_words)
  VECTOR_SIZE    = 300

  # Compile the patterns once instead of on every line of the (very long) file
  base_word_pattern   = re.compile("^[a-z]+$|^'[sd]$")
  suffix_word_pattern = re.compile("^[a-z]+'[sd]$")

  # Reset dictionary and inverse dictionary of words and word embedding vectors list
  words2indices.clear()
  indices2words.clear()
  word_vectors = []

  # Add additional words to the vocabulary (these words are not present in the word2vec file)
  for word in additional_words:
    # "ain't" is handled below, when "isn't" is found in the file
    if word == "ain't":
      continue

    # Add the word with a random vector
    add_word(word, torch.rand(VECTOR_SIZE))

  # Word2vec file contains words (including punctuation and contractions) with some permutations of lower and upper case letters,
  # so don't worry about the case that a word may be present only in uppercase letters

  print("Creating vocabulary using Word2vec pre-trained embedding...\n")

  # Scan word2vec file (explicit encoding so the result does not depend on the platform default)
  with open(word2vec_path, encoding="utf-8") as f:
    # Skip the first line of the file
    next(f, None)

    for l in tqdm(f, total=2_000_000):
      # Check if the desired vocabulary size has been achieved
      if len(word_vectors) == VOCABULARY_LEN:
        print("Desired vocabulary size achieved, stop vectors iterations!")
        break

      # Get the word and the corresponding embedding vector
      word, *vector = l.strip().split()

      # Keep only words that appear in the preprocessed lyrics
      if word in lyrics_vocab:
        # Convert to float the word's vector values
        vector = torch.tensor([float(c) for c in vector])

        # Add word and corresponding vector to the vocabulary
        add_word(word, vector)

        # Check if the added word is "isn't"
        if word == "isn't":
          # Add "ain't" as a little perturbation of isn't vector
          vector = torch.rand(VECTOR_SIZE).clamp(-0.15, 0.15).add(vector)
          add_word("ain't", vector)

      # Remember skipped words that are lowercase (exclusively constituted by alpha chars)
      # or equal to "'s" or "'d": they are used below to build missing vectors
      elif base_word_pattern.match(word):
        skipped_words[word] = torch.tensor([float(c) for c in vector])

  # Iterate the words in the lyrics to include in the vocabulary words not in the word2vec file
  # (for the moment include only words + 's or + 'd trying to perturbate existing vectors)
  for word in lyrics_vocab:
    # Only words still missing from the vocabulary and shaped like word + 's / word + 'd
    if word in words2indices or not suffix_word_pattern.match(word):
      continue

    # Get the word before the apostrophe and the ending (i.e. "'s" or "'d")
    word_, ending = word.split("'")
    ending = "'" + ending

    # Get the embedding of the word before the apostrophe, from the vocabulary
    # or from the skipped words
    vector_ = get_word_vector(word_)
    if vector_ is None:
      vector_ = skipped_words.get(word_)

    # No vector is known for the base word: skip the word
    if vector_ is None:
      continue

    # Get the embedding of the ending (i.e. 's or 'd)
    vector = get_word_vector(ending)
    if vector is None:
      vector = skipped_words.get(ending)

    # Else use a perturbation vector
    if vector is None:
      vector = torch.rand(VECTOR_SIZE).clamp(-0.03, 0.03)

    # The vector of the new word is the sum of the vectors of the word and of the ending
    add_word(word, vector_.add(vector))

  # Convert the list of word embedding vectors to a tensor
  word_vectors = torch.stack(word_vectors)

  # Try to move word_vectors on the GPU: Tensor.cuda() returns a copy and does not
  # modify the tensor in place, so the result must be assigned back
  if torch.cuda.is_available():
    word_vectors = word_vectors.cuda()

  print("Vocabulary created")

Define a function to calculate cosine similarity between two word embedding vectors

In [None]:
# Compute cosine similarity between two word embedding vectors
def cosine_similarity(vector_a, vector_b):
  """Return the cosine similarity of two vectors as a Python float.

  Returns -1 when either vector is None (e.g. the word is not in the vocabulary).
  """
  # Identity test: comparing a tensor with `== None` only works through a fallback
  if vector_a is None or vector_b is None:
    return -1

  # Calculate the dot product of the inputs
  num = torch.sum(vector_a * vector_b)

  # Calculate the product of the inputs' norms
  den = torch.norm(vector_a) * torch.norm(vector_b)

  # Avoid division by zero
  if den == 0.0:
    den = 1e-8

  return (num/den).item()

# **Dataset**

Create the dataset class, which
* loads the preprocessed lyrics dataset,
* splits the lyrics into sequences

In [None]:
# Number of words in each input sequence (the word that follows the sequence is its target)
SEQUENCE_LENGTH = 6

# Genre name -> one hot encoding vector, filled while the dataset is parsed
one_hot_encoding_genres = {}
# Number of entries of one_hot_encoding_genres, set while the dataset is parsed
NUMBER_GENRES = 0

# List of the lyrics words without a vector, sorted by occurrences, set while the dataset is parsed
missing_words_sorted = None

class Dataset(torch.utils.data.Dataset):
    """Lyrics dataset yielding [(sequences, genre one hot vector), targets] samples.

    Creating an instance reads the CSV file at `dataset_path`, fills the
    module-level lyrics vocabulary, calls create_vocab(), splits every lyrics
    into sequences of SEQUENCE_LENGTH word indices and stores the result as
    tensors on `device`.

    The base class is referenced through its full path because this class
    reuses (and therefore shadows) the imported name `Dataset`.
    """

    def __init__(self, dataset_path: str):
      self.dataset_path = dataset_path
      self.data_input_lyrics = []
      self.data_input_genres = []
      self.data_target       = []

      self.parse_data()

    # Create One Hot Encoding dictionary
    def create_ohe_dict(self, row, col_names):
      """Store the one hot vector in `row` under the name of the genre it encodes.

      The genre is the first column of `col_names` whose value in `row` is 1;
      nothing is stored when no column contains a 1.
      """
      for col_name in col_names:
        # Look for the column that contains the 1 of the one hot encoding vector
        if row[col_name] != 1:
          continue

        # Keep only the original name (e.g. genre_Metal -> Metal); a column name
        # without underscore is used as it is
        parts = col_name.split('_')
        name = parts[1] if len(parts) > 1 else col_name

        # Add the label and the corresponding ohe vector to the dictionary
        one_hot_encoding_genres[name] = [int(x) for x in row.values]
        return

    # Create lyrics vocabulary
    def create_lyrics_vocab(self, text):
      """Add every word of `text` (split on single spaces) to `lyrics_vocab`."""
      # Empty strings (due to consecutive whitespaces) are discarded
      lyrics_vocab.update(w for w in text.split(' ') if w)

    # Split lyrics into input sequences and target words by using a sliding window of words
    def lyrics_splitting(self, text):
      """Split `text` into sequences of word indices and their target word indices.

      Returns [sequences, targets] holding exactly MAX_SEQ_NUMBER - 1 entries each,
      or None when that number of valid sequences cannot be extracted from the text.
      A sequence is valid when all its words and its target word are in the vocabulary.
      """
      MAX_SEQ_NUMBER = 41

      # Split on spaces and remove empty strings
      text_splitted = [w for w in text.split(' ') if w]

      # Calculate and limit the number of sequences in the text to not exhaust all available RAM/DISK
      sequences_available = len(text_splitted) - SEQUENCE_LENGTH
      sequences_number    = min(sequences_available, MAX_SEQ_NUMBER) - 1 # -1 for the target

      # The lyrics is too short to provide the required number of sequences
      if sequences_number < MAX_SEQ_NUMBER-1:
        return None

      sequences = []
      targets   = []

      # Slide a window of SEQUENCE_LENGTH words over the text
      for i in range(sequences_available):
        # Check if the number of sequences has been achieved
        if len(sequences) == sequences_number:
          break

        # Convert the sequence from words to indices, stopping at the first unknown word
        seq_indices = []
        for word in text_splitted[i:i + SEQUENCE_LENGTH]:
          index = get_index_from_word(word)
          if index is None:
            break
          seq_indices.append(index)

        # A word of the sequence is not in the vocabulary: skip this sequence
        # (evaluate entering something to reset LSTM states when adjacent sequences are not really such)
        if len(seq_indices) < SEQUENCE_LENGTH:
          continue

        # The target is the word right after the sequence; skip the sequence if it is unknown
        target_index = get_index_from_word(text_splitted[i + SEQUENCE_LENGTH])
        if target_index is None:
          continue

        # Append the sequence and the target in the corresponding lists
        sequences.append(seq_indices)
        targets.append(target_index)

      # Every lyrics must contribute exactly the same number of sequences,
      # otherwise the lists could not be stacked in a tensor without padding
      if len(sequences) != MAX_SEQ_NUMBER-1:
        return None

      return [sequences, targets]

    def parse_data(self):
      """Load the CSV, build the vocabulary and fill the input/target tensors.

      Also sets the module-level NUMBER_GENRES and missing_words_sorted.
      """
      global NUMBER_GENRES, missing_words_sorted

      # Read CSV file and get columns of interest
      print("Reading the dataset...")
      data = pd.read_csv(self.dataset_path)
      data = data[['genre', 'lyrics']].dropna()

      print("Dataset size: ", len(data.index))

      # Removing null rows consisting of non-english lyrics
      print("\nRemoving non-english lyrics detected during preprocessing...")
      data.dropna(inplace=True)

      print("Dataset size: ", len(data.index))

      # Create vocabulary starting from the preprocessed lyrics
      print("\nCreating vocabulary from preprocessed lyrics...")
      # First create lyrics vocabulary
      data['lyrics'].apply(self.create_lyrics_vocab)
      # Then create vocabulary using pretrained word embedding and lyrics vocabulary
      create_vocab()

      # Print top 50 words in the preprocessed lyrics but not in the word2vec file,
      # from the most to the least frequent
      print("\nPrinting top 50 words not in the word2vec file in descending order of occurrences:")
      missing_words = {}

      # For each word without a vector, count occurrences and store the first lyrics containing it
      for lyrics in data['lyrics']:
        for word in lyrics.split(' '):
          # Ignore empty strings and words that are in the vocabulary
          if not word or word in words2indices:
            continue

          if word in missing_words:
            missing_words[word][0] += 1
          else:
            missing_words[word] = [1, lyrics]

      # Sort words by occurrences in descending order
      missing_words_sorted = sorted(missing_words.items(), key=lambda x: x[1][0], reverse=True)

      # Slicing copes with fewer than 50 missing words
      for word, (occurrences, lyrics) in missing_words_sorted[:50]:
        print("Text: ", lyrics)
        print("\nMissing word: ", word)
        print("Missing word occurrences: ", occurrences)
        print("*"*40)

      print("lyrics_vocab size: ", len(lyrics_vocab))
      print("word_vectors size: ", len(word_vectors))
      print("Missing words in word_vectors = lyrics_vocab size - word_vectors size: ", len(lyrics_vocab) - len(word_vectors))

      # Split lyrics into input and target sequences of SEQUENCE_LENGTH words
      print("Splitting lyrics into sequences...")
      data['lyrics'] = data['lyrics'].apply(self.lyrics_splitting)

      # Remove null rows (i.e. lyrics that don't satisfy splitting requirements)
      print("Removing null rows...")
      size_with_na = len(data.index)
      data.dropna(inplace=True)
      # Guard against an empty dataset
      removed_pct = (1 - len(data.index)/size_with_na)*100 if size_with_na else 0.0
      print("Removed lines: %d/%d => %.2f%% " %(size_with_na - len(data.index), size_with_na, removed_pct))

      # Count occurrences (i.e. lyrics) per genre in the dataset
      print("\nLyrics per genre:")
      print(data['genre'].value_counts())

      # One hot encode genres: the 'genre' column is replaced by one column per genre,
      # placed after the 'lyrics' column
      print("\nOne hot encoding genres...")
      data = pd.get_dummies(data, columns=['genre'])
      genre_columns = data.iloc[:, 1:]

      # Get genres without duplicates and store their one hot vectors
      genres_without_duplicates = genre_columns.drop_duplicates()
      genres_without_duplicates.apply(self.create_ohe_dict, axis=1, col_names=genres_without_duplicates.columns)

      NUMBER_GENRES = len(one_hot_encoding_genres)
      print("Number of unique genres:", NUMBER_GENRES)

      # Get input and target
      inputs = pd.DataFrame()
      inputs['lyrics'] = data['lyrics'].apply(lambda x: x[0])
      inputs['genre']  = data.apply(lambda x: x.values[1:], axis=1)
      target = data['lyrics'].apply(lambda x: x[1])

      print("\ninput:")
      print(inputs[:10])
      print("\ntarget:")
      print(target[:10])

      # Build the tensors from plain lists / a numeric matrix: after dropna the index
      # labels are no longer contiguous, so a Series must not be indexed by position
      self.data_input_lyrics  = torch.tensor(inputs['lyrics'].tolist(), dtype=torch.long, device=device)
      self.data_input_genres  = torch.tensor(genre_columns.values.astype("int64"), dtype=torch.long, device=device)
      self.data_target = torch.tensor(target.tolist(), dtype=torch.long, device=device)

      print("input (lyrics) tensor size:", self.data_input_lyrics.size())
      print("input (genres) tensor size:", self.data_input_genres.size())
      print("target tensor size:", self.data_target.size())

    def __len__(self):
      """Return the number of samples (i.e. lyrics kept after the splitting)."""
      return len(self.data_input_lyrics)

    def __getitem__(self, idx):
      """Return [(lyrics sequences, genre vector), targets] of the sample at `idx`."""
      return [(self.data_input_lyrics[idx], self.data_input_genres[idx]), self.data_target[idx]]

In [None]:
# Instantiate train dataset (the constructor parses the CSV file and builds the vocabulary)
train_dataset = Dataset(dataset_path=dataset_path)
# Iterate the dataset in shuffled batches of 64 samples
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)


# Code to split dataset in train and test
# Option 1
'''
dataset = Dataset(dataset_path=file_path)
batch_size      = 64
test_split      = .2
shuffle_dataset = True
random_seed     = 42

# Get split index
dataset_size = len(dataset)
indices = list(range(dataset_size))
split_index = int(np.floor(test_split * dataset_size))

# Shuffle dataset (if necessary)
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)

# Get data indices for training and test splits
train_indices = indices[split_index:]
test_indices  = indices[:split_index]

# Create data samplers and loaders
train_sampler = SubsetRandomSampler(train_indices)
test_sampler  = SubsetRandomSampler(test_indices)

train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
test_loader  = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
'''

# Option 2
'''
dataset = Dataset(dataset_path=file_path)
dataset_size = len(dataset)
BATCH_SIZE   = 64
TRAIN_SPLIT  = .8
SEED         = 2147483647

train_size = int(train_split * dataset_size)
test_size  = dataset_size - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(random_seed))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader  = DataLoader(test_dataset, batch_size=batch_size)
'''

Reading the dataset...
Dataset size:  2000

Removing non-english lyrics detected during preprocessing...
Dataset size:  2000

Creating vocabulary from preprocessed lyrics...
Creating vocabulary using Word2vec pre-trained embedding...



  0%|          | 0/2000000 [00:00<?, ?it/s]

Vocabulary created

Printing top 50 words not in the word2vec file in ascending order of occurrences:
Text:  intro :  
 kid's voice :  
 wise ... wise .... blessed ... blessed ...  
 more fire !  !  ... more fire !  !  .... red 
 capleton :  
 more fire !  !  !  red hot !  !  !  !  yuh see ah now well done ,  yo !  !  
 light up di fire from mi ready fi put it pon dem 
 hey gimme di hey gimme di yo !  !  !  !  
 chorus :  
 dat one yah name cuyah cuyah cuyah when dem see mi wid di fire 
 cuyah cuyah cuyah when mi bun di vampire 
 cuyah cuyah cuyah when mi bun di obeah wuker 
 tell dem cuyah !  !  !  ... cuyah !  !  !  ... cuyah !  !  !  ... again 
 cuyah cuyah cuyah when dem see mi wid di fire 
 cuyah cuyah cuyah when mi bun di obeah wuker 
 cuyah cuyah cuyah when mi bun di vampire 
 tell dem cuyah !  !  !  ... cuyah !  !  !  ... cuyah !  !  !  
 well there is ... nuttin to nuh worry about 
 there is ... nuttin to confuse about 
 there is ,  when mi seh nuttin to nuh carry about 
 caus

'\ndataset = Dataset(dataset_path=file_path)\ndataset_size = len(dataset)\nBATCH_SIZE   = 64\nTRAIN_SPLIT  = .8\nSEED         = 2147483647\n\ntrain_size = int(train_split * dataset_size)\ntest_size  = dataset_size - train_size\ntrain_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(random_seed))\n\ntrain_dataloader = DataLoader(train_dataset, batch_size=batch_size)\ntest_dataloader  = DataLoader(test_dataset, batch_size=batch_size)\n'

Calculate cosine similarity between two words

In [None]:
# ANSI escape sequences used to style the text printed below
class string_style:
  BOLD    = "\033[1m"  # start bold text
  ITALICS = "\033[3m"  # start italic text
  END     = "\033[0m"  # reset the style

# Take two random words in the vocabulary.
# Choose from an explicit list of words: random.choice() on the dictionary itself
# only works as long as its keys are exactly 0..n-1
vocabulary_words = list(indices2words.values())
word_0 = random.choice(vocabulary_words)
word_1 = random.choice(vocabulary_words)

# Retrieve corresponding word embedding vectors
word_vector_0 = get_word_vector(word_0)
word_vector_1 = get_word_vector(word_1)

# Style subsequent word printings
word_0 = string_style.BOLD + word_0 + string_style.END
word_1 = string_style.BOLD + word_1 + string_style.END

# Calculate cosine similarity through custom function
cos_sim = cosine_similarity(word_vector_0, word_vector_1)
print("cosine similarity of %s and %s: %.3f" %(word_0, word_1, cos_sim))

# Calculate cosine similarity through torch function, as a cross-check of the custom one
if word_vector_0 is not None and word_vector_1 is not None:
  cos_sim_fn = torch.nn.CosineSimilarity(dim=0)
  cos_sim = cos_sim_fn(word_vector_0, word_vector_1)
  print("cosine similarity of %s and %s: %.3f" %(word_0, word_1, cos_sim))

cosine similarity of [1mquickie[0m and [1mimmune[0m: 0.004
cosine similarity of [1mquickie[0m and [1mimmune[0m: 0.004


Display some words in the preprocessed lyrics but not in the word2vec file


In [None]:
# List every lyrics word that did not get an entry in the vocabulary
print("Words not in the word2vec file:")
for word in filter(lambda w: w not in words2indices, lyrics_vocab):
  print(word)

Words not in the word2vec file:
r&b
blackbir
strangeling
russkies
'90s
nuffer
mepart
iwisa
droller
nkhopotsa
around'
allajjliiiineeed
paaassion
goin'to
so'
lemonadea
nasceva
selassi
pine'
mebeyoncé
we�re
dubane
wabona
bangled
#5
gask
redda
babymother
nuttyy
zikhand'
linched
sports'
pornpage
muhfucking
renumerate
borogroves
kerotine
downtrodden‚
daynas
vsst
fareweeli
sognavo
blessins
griffy
lonke
wctu
sitay'
soget
bwal
woooiii
mammi
innisfail
chalwa
anthemt
youve
tarbuck
f'ed
comptowns
youselfint
dodda
ii'm
tswalang
sognava
quo'
girls'
fak'
geerie
skhwama
oxed
a'
brrrrrup
sinting
bgko
spane
tapsalteerie
heeeeeeey
sugartooth
supahype
torea
geppetto
knoccie
them‚
i'
ndawo
hmmhmm
appomattox
umazi
zami
rockity
bughtin
pendles
ritis
sinder
dallaz
gooving
unended
reeeeeeeeeeeeeeeeddd
fabbin
kimmers
shoulodn't
ill
nahâ€¦
loo'ed
toasta
familyly
gungster
tshembi
couldah
ngavavenge
you'
pipie's
athol
remenance
santificado
j&b
skiiirt
say'
machankura
'hey'
voni
is'cathulo
angels'
rocafella
piran

In [None]:
# Print every missing word with its number of occurrences, in the order of missing_words_sorted
print("Words not in the word2vec file and occurrences:")
for word, (occurrences, lyrics) in missing_words_sorted:
  print("Missing word: ", word)
  print("Occurrences: ", occurrences)
  print("*" * 40)

Words not in the word2vec file and occurrences:
Missing word:  cuyah
Occurrences:  60
****************************************
Missing word:  o'
Occurrences:  58
****************************************
Missing word:  reflektor
Occurrences:  51
****************************************
Missing word:  a'
Occurrences:  42
****************************************
Missing word:  wha'
Occurrences:  37
****************************************
Missing word:  wayase
Occurrences:  34
****************************************
Missing word:  shooby
Occurrences:  33
****************************************
Missing word:  &c
Occurrences:  27
****************************************
Missing word:  unuh
Occurrences:  26
****************************************
Missing word:  wi'
Occurrences:  26
****************************************
Missing word:  bus'
Occurrences:  24
****************************************
Missing word:  'bou'
Occurrences:  21
****************************************
Missing word

# **Models**

Define the generator neural network

In [None]:
class Generator(nn.Module):
  """Next-word generator: embedding -> LSTM -> dense layer over the vocabulary.

  The genre vector of each sample is concatenated to the LSTM output of every
  time step before the dense layer, so `dense_size` must equal
  `lstm_hidden_size` plus the length of the genre vector.
  """

  def __init__(
      self,
      word_vectors: torch.Tensor,
      lstm_hidden_size: int,
      dense_size: int,
      vocab_size: int
  ):
    super().__init__()

    # Embedding layer initialized with the pretrained word vectors
    self.embedding = torch.nn.Embedding.from_pretrained(word_vectors)

    # Recurrent layer (LSTM)
    self.rnn = torch.nn.LSTM(input_size=word_vectors.size(1), hidden_size=lstm_hidden_size, num_layers=1, batch_first=True)

    # Dense layer
    self.dense = torch.nn.Linear(dense_size, vocab_size)
    torch.nn.init.uniform_(self.dense.weight)

    # Dropout function (defined but not applied in forward)
    self.dropout = nn.Dropout(p=0.1)

    # Loss function
    self.loss = torch.nn.CrossEntropyLoss()

    self.global_epoch = 0

  def forward(self, x, y=None, states=None):
    """Predict the word that follows each input sequence.

    Args:
      x: pair (lyrics, genres); lyrics holds word indices with shape
        (batch, sequence length), genres holds one genre vector per sample.
      y: optional target word indices, one per sample.
      states: optional LSTM states to start from.

    Returns:
      A dictionary with 'logits' (probabilities over the vocabulary, i.e. the
      dense output after softmax), 'pred' (one sampled word index per sample)
      and 'states' (the LSTM states); 'loss' and 'accuracy' are added when
      `y` is given.
    """
    # Split input in lyrics and genre
    lyrics, genres = x[0], x[1]

    # Embedding words from indices
    out = self.embedding(lyrics)

    # Recurrent layer
    out, states = self.rnn(out, states)

    # Repeat the genre vector of each sample for every word of its sequence:
    # (batch, genres) -> (batch, sequence length, genres)
    seq_length = lyrics.size(1)
    genres = genres.unsqueeze(1).expand(-1, seq_length, -1).to(device=out.device, dtype=out.dtype)

    # Concatenate the LSTM output with the encoding of genres
    out = torch.cat((out, genres), dim=-1)

    # Dense layer
    out = self.dense(out)

    # Use the last prediction
    raw_logits = out[:, -1, :]

    # Turn the logits into a probability distribution over the vocabulary
    probabilities = torch.softmax(raw_logits, dim=-1)

    # Max likelihood can return repeated sequences over and over, so sample from the
    # multinomial distribution instead: one vocabulary index for each row
    sampled_indices = torch.multinomial(probabilities, num_samples=1)

    result = {'logits': probabilities, 'pred': sampled_indices, 'states': states}

    if y is not None:
      # CrossEntropyLoss applies log-softmax itself, so it must receive the raw
      # logits and not the probabilities
      result['loss']     = self.loss(raw_logits, y)
      result['accuracy'] = self.accuracy(sampled_indices, y.unsqueeze(-1))

    return result

  def accuracy(self, pred, target):
    """Return the fraction of predictions equal to the target."""
    return torch.sum(pred == target) / pred.size()[0]

Define the discriminator neural network

In [None]:
class Discriminator(nn.Module):
  """Binary classifier of word sequences: embedding -> LSTM -> dense layer -> sigmoid.

  The genre vector of each sample is concatenated to the LSTM output of every
  time step before the dense layer, so `dense_size` must equal
  `lstm_hidden_size` plus the length of the genre vector.
  """

  def __init__(
    self,
    word_vectors: torch.Tensor,
    lstm_hidden_size: int,
    dense_size: int
  ):
    super().__init__()

    # Embedding layer initialized with the pretrained word vectors
    self.embedding = torch.nn.Embedding.from_pretrained(word_vectors)

    # Recurrent layer (LSTM)
    self.rnn = torch.nn.LSTM(input_size=word_vectors.size(1), hidden_size=lstm_hidden_size, num_layers=3, batch_first=True)

    # Dense layer
    self.dense = torch.nn.Linear(dense_size, 1)
    torch.nn.init.uniform_(self.dense.weight)

    # Dropout function (defined but not applied in forward)
    self.dropout = nn.Dropout(p=0.1)

    # Activation function
    self.out_act = torch.nn.Sigmoid()

    # Loss function
    self.loss = torch.nn.BCELoss()

    self.global_epoch = 0

  def forward(self, x, y=None):
    """Classify each input sequence.

    Args:
      x: pair (lyrics, genres); lyrics holds word indices with shape
        (batch, sequence length), genres holds one genre vector per sample.
      y: optional target class (0 or 1) of each sample.

    Returns:
      A dictionary with 'logits' (the sigmoid output, one value per sample)
      and 'pred' (the output rounded to 0 or 1); 'loss' and 'accuracy' are
      added when `y` is given.
    """
    # Split input in lyrics and genre
    lyrics, genres = x[0], x[1]

    # Embedding words from indices
    embedding_out = self.embedding(lyrics)

    # Recurrent layer (only the output sequence is used, not the states)
    out = self.rnn(embedding_out)[0]

    # Repeat the genre vector of each sample for every word of its sequence:
    # (batch, genres) -> (batch, sequence length, genres)
    seq_length = lyrics.size(1)
    genres = genres.unsqueeze(1).expand(-1, seq_length, -1).to(device=out.device, dtype=out.dtype)

    # Concatenate the LSTM output with the encoding of genres
    out = torch.cat((out, genres), dim=-1)

    # Dense layer
    out = self.dense(out)

    # Use the last prediction
    out = out[:, -1, :]

    # Sigmoid activation function
    logits = self.out_act(out)

    # Round to return one of the two classes (0 and 1) for each entry
    pred = torch.round(logits)

    result = {'logits': logits, 'pred': pred}

    if y is not None:
      # BCELoss needs a target with the shape and the dtype of its input
      y = y.unsqueeze(-1).to(logits.dtype)
      result['loss']     = self.loss(logits, y)
      result['accuracy'] = self.accuracy(pred, y)

    return result

  def accuracy(self, pred, target):
    """Return the fraction of predictions equal to the target."""
    return torch.sum(pred == target) / pred.size()[0]

Instantiate the models and the optimizers

In [None]:
############ Parameters: #############

# The dense layers receive the LSTM output concatenated with the genre vector
GEN_LSTM_SIZE   = 256
GEN_DENSE_SIZE  = 256 + NUMBER_GENRES
GEN_HIDDEN_SIZE = 2048

DIS_LSTM_SIZE   = 256
DIS_DENSE_SIZE  = 256 + NUMBER_GENRES
DIS_HIDDEN_SIZE = 64

# NO GAN
#LR = 0.01

# GAN
# Learning rate for optimizers
LR = 0.0002
# Betas hyperparam for Adam optimizers
BETAS = (0.5, 0.999)

######################################

gen = Generator(
    word_vectors,
    lstm_hidden_size=GEN_LSTM_SIZE,
    dense_size=GEN_DENSE_SIZE,
    vocab_size=len(word_vectors))

dis = Discriminator(
    word_vectors,
    lstm_hidden_size=DIS_LSTM_SIZE,
    dense_size=DIS_DENSE_SIZE)

# Try to move the models on the GPU.
# This is done before constructing the optimizers, as recommended by the torch.optim documentation
if torch.cuda.is_available():
  gen.cuda()
  dis.cuda()

gen_optimizer = torch.optim.Adam(gen.parameters(), lr=LR, betas=BETAS)
dis_optimizer = torch.optim.Adam(dis.parameters(), lr=LR, betas=BETAS)

# Print the models summaries
print(gen)
print(dis)

Generator(
  (embedding): Embedding(17626, 300)
  (rnn): LSTM(300, 256, batch_first=True)
  (lin1): Linear(in_features=266, out_features=17626, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (loss): CrossEntropyLoss()
)
Discriminator(
  (embedding): Embedding(17626, 300)
  (rnn): LSTM(300, 256, num_layers=3, batch_first=True)
  (linear_one): Linear(in_features=266, out_features=1, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (loss): BCELoss()
  (out_act): Sigmoid()
)


# **Training loop**

Train only the generator

In [None]:
PATH = '/content/drive/MyDrive/DM project - NLP lyrics generation/generator_model.pt'

def train_and_evaluate_gen(
    generator:     nn.Module,
    discriminator: nn.Module,
    gen_optimizer: torch.optim.Optimizer,
    dis_optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    epochs: int = 5,
    start_epoch: int = 0, 
    verbose: bool = True,
    checkpoint_path: Optional[str] = None
):
    """Train only the generator and checkpoint it after every epoch.

    Each batch ``(x, y)`` is unpacked as ``x[0]`` (input sequences) and
    ``x[1]`` (genres). The second dimension of ``x[0]`` is iterated: for every
    index ``i`` the generator is run on ``x[0][:, i]`` with target ``y[:, i]``
    and the LSTM states returned by the previous step (``None`` at the start
    of each batch), followed by one optimizer step.

    Args:
        generator: model being trained; must expose a ``global_epoch`` counter
            and return a dict with 'states', 'loss' and 'accuracy'.
        discriminator: unused here; kept for a signature parallel to the other
            training functions.
        gen_optimizer: optimizer updating the generator.
        dis_optimizer: unused here.
        train_dataloader: iterable of ``(x, y)`` batches.
        epochs: exclusive upper bound of the epoch range.
        start_epoch: first epoch index (use it to resume training).
        verbose: print the summary line every epoch (the last epoch is always
            printed).
        checkpoint_path: where the checkpoint is written; defaults to the
            module-level ``PATH`` at call time.

    Returns:
        Dict with 'gen_train_history' (mean generator loss per epoch) and
        'dis_train_history' (always empty in this function).
    """
    if checkpoint_path is None:
        checkpoint_path = PATH

    def detach_states(states):
        # Cut the LSTM states from their autograd history so that backprop
        # does not reach into previous steps.
        if isinstance(states, torch.Tensor):
            return states.detach()
        return tuple(detach_states(item) for item in states)

    gen_train_history = []
    dis_train_history = []

    # Iterate epochs starting from start_epoch
    for epoch in range(start_epoch, epochs):
        gen_losses     = []
        gen_accuracies = []

        # Iterate batches of the training set
        for x, y in tqdm(train_dataloader):
            input_sequences = x[0]
            input_genres    = x[1]
            seq_numbers = input_sequences.size()[1]

            # Init generator states for the LSTM
            gen_states = None

            # Iterate sequences and targets of the batch
            for i in range(seq_numbers):
                # Get the i-th sequence (column) and the i-th target (column)
                inputs = input_sequences[:, i]
                target = y[:, i]

                # Zero the gradients on each iteration
                gen_optimizer.zero_grad()

                # Run the generator
                gen_out = generator((inputs, input_genres), target, gen_states)

                gen_loss     = gen_out['loss']
                gen_accuracy = gen_out['accuracy']

                # Carry the states to the next step, detached from this step's graph
                gen_states = detach_states(gen_out['states'])

                # Backpropagate the loss and update the generator
                gen_loss.backward()
                gen_optimizer.step()

                # Track progress. The loss is detached so that the stored value
                # does not keep the autograd graph of every step alive.
                gen_losses.append(gen_loss.detach())
                gen_accuracies.append(gen_accuracy)

        # Use the model that was passed in, not the notebook-level global
        generator.global_epoch += 1

        gen_mean_loss     = sum(gen_losses) / len(gen_losses)
        gen_mean_accuracy = sum(gen_accuracies) / len(gen_accuracies)
        gen_train_history.append(gen_mean_loss.item())
        if verbose or epoch == epochs - 1:
            print(f'  Epoch {generator.global_epoch:3d} => Generator loss: {gen_mean_loss:0.6f}, Generator accuracy: {gen_mean_accuracy:0.3f} = {gen_mean_accuracy*100:0.3f}')

        # Save generator model (checkpoint).
        # NOTE: 'epoch' is the index of the epoch that has just been completed.
        torch.save({
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': gen_optimizer.state_dict(),
            'loss': gen_mean_loss,
            'accuracy': gen_mean_accuracy
            }, checkpoint_path)

    return {
        'gen_train_history': gen_train_history,
        'dis_train_history': dis_train_history
    }

Train only the discriminator

In [None]:
PATH = '/content/drive/MyDrive/DM project - NLP lyrics generation/discriminator_model.pt'

def train_and_evaluate_dis(
    generator:     nn.Module,
    discriminator: nn.Module,
    gen_optimizer: torch.optim.Optimizer,
    dis_optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    epochs: int = 5,
    start_epoch: int = 0, 
    verbose: bool = True,
    checkpoint_path: Optional[str] = None
):
    """Train only the discriminator against randomly sampled fake sequences.

    For every index along the second dimension of ``x[0]`` the discriminator
    is run on the real sequences (target 1) and on a batch of
    ``SEQUENCE_LENGTH`` random vocabulary indices (target 0); the two losses
    are summed and one optimizer step is taken. The generator is not used to
    build the fake sequences.

    Args:
        generator: unused here; kept for a signature parallel to the other
            training functions.
        discriminator: model being trained; must expose a ``global_epoch``
            counter and return a dict with 'loss' and 'accuracy'.
        gen_optimizer: unused here.
        dis_optimizer: optimizer updating the discriminator.
        train_dataloader: iterable of ``(x, y)`` batches, ``x`` being
            ``(sequences, genres)``.
        epochs: exclusive upper bound of the epoch range.
        start_epoch: first epoch index (use it to resume training).
        verbose: print the summary line every epoch (the last epoch is always
            printed).
        checkpoint_path: where the checkpoint is written; defaults to the
            module-level ``PATH`` at call time.

    Returns:
        Dict with 'dis_train_history' (mean discriminator loss per epoch) and
        'gen_train_history' (always empty in this function).
    """
    if checkpoint_path is None:
        checkpoint_path = PATH

    gen_train_history = []
    dis_train_history = []

    vocabulary_size = len(words2indices)

    # Iterate epochs starting from start_epoch
    for epoch in range(start_epoch, epochs):
        dis_losses     = []
        dis_accuracies = []

        # Iterate batches of the training set
        for x, y in tqdm(train_dataloader):
            input_sequences = x[0]
            input_genres    = x[1]
            seq_numbers      = input_sequences.size()[1]
            input_batch_size = input_sequences.size()[0]

            # Real and fake targets for the discriminator (constant within a batch)
            target_real = torch.ones([input_batch_size], device=device)
            target_fake = torch.zeros([input_batch_size], device=device)

            # Iterate sequences of the batch
            for seq_idx in range(seq_numbers):
                # Get the seq_idx-th sequence (column)
                inputs = input_sequences[:, seq_idx]

                # Build a fake batch: input_batch_size x SEQUENCE_LENGTH random
                # vocabulary indices. Indices below 3 are never drawn (the
                # original comment says this skips the PAD/UNK entries).
                fake = torch.tensor(
                    [
                        [random.randrange(3, vocabulary_size) for _ in range(SEQUENCE_LENGTH)]
                        for _ in range(input_batch_size)
                    ],
                    device=device)

                # Zero the gradients on each iteration
                dis_optimizer.zero_grad()

                # Run the discriminator on real sequences
                discriminator_out_real = discriminator((inputs, input_genres), target_real)
                discriminator_loss_real = discriminator_out_real['loss']
                discriminator_acc_real  = discriminator_out_real['accuracy']

                # Run the discriminator on fake sequences
                discriminator_out_fake = discriminator((fake, input_genres), target_fake)
                discriminator_loss_fake = discriminator_out_fake['loss']
                discriminator_acc_fake  = discriminator_out_fake['accuracy']

                # Sum the losses, average the accuracies
                dis_loss     = discriminator_loss_real + discriminator_loss_fake
                dis_accuracy = (discriminator_acc_real + discriminator_acc_fake) / 2

                dis_loss.backward()
                dis_optimizer.step()

                # Track progress. The loss is detached so that the stored value
                # does not keep the autograd graph of every step alive.
                dis_losses.append(dis_loss.detach())
                dis_accuracies.append(dis_accuracy)

        # Use the model that was passed in, not the notebook-level global
        discriminator.global_epoch += 1

        dis_mean_loss     = sum(dis_losses) / len(dis_losses)
        dis_mean_accuracy = sum(dis_accuracies) / len(dis_accuracies)
        dis_train_history.append(dis_mean_loss.item())
        if verbose or epoch == epochs - 1:
            # Report the discriminator's own epoch counter (the generator's was printed before)
            print(f'  Epoch {discriminator.global_epoch:3d} => Discriminator loss: {dis_mean_loss:0.6f}, Discriminator accuracy: {dis_mean_accuracy:0.3f} = {dis_mean_accuracy*100:0.3f}')

        # Save discriminator model (checkpoint).
        # 'accuracy' is stored as well because the restore cells read that key.
        torch.save({
            'epoch': epoch,
            'model_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': dis_optimizer.state_dict(),
            'loss': dis_mean_loss,
            'accuracy': dis_mean_accuracy
            }, checkpoint_path)

    return {
        'gen_train_history': gen_train_history,
        'dis_train_history': dis_train_history
    }

Train the GAN

In [None]:
# Checkpoint file written by train_and_evaluate_GAN and read back by the restore cell.
# NOTE: PATH is a single notebook-level global that the generator-only and
# discriminator-only training cells also assign, so the most recently executed
# assignment decides which file the training functions write to.
#PATH = '/content/drive/MyDrive/DM project - NLP lyrics generation/generator_model_GAN.pt' # train gen and then dis
PATH = '/content/drive/MyDrive/DM project - NLP lyrics generation/generator_model_GAN_v3.pt' # train dis and then gen but using different lr and betas


# Wrap hidden states in new Tensors, to detach them from their history
def repackage_states(states):
  """Return ``states`` with every tensor detached from its autograd history.

  A tensor is returned detached; any other iterable is rebuilt recursively as
  a tuple of repackaged items. ``None`` (the initial LSTM state used by the
  training loops) is passed through unchanged instead of raising a TypeError.
  """
  if states is None:
    return None
  if isinstance(states, torch.Tensor):
    return states.detach()
  return tuple(repackage_states(item) for item in states)

# Convert input sequences of indices in sequences of words
def convert(input):
  """Map every index of every sequence in ``input`` to its word.

  Each element is read with ``.item()`` and looked up through
  ``get_word_from_index``; the result is a list of word lists, one per sequence.
  """
  return [
      [get_word_from_index(token.item()) for token in sequence]
      for sequence in input
  ]


def train_and_evaluate_GAN(
    generator:     nn.Module,
    discriminator: nn.Module,
    gen_optimizer: torch.optim.Optimizer,
    dis_optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    start_epoch: int = 0, 
    epochs: int = 5,
    verbose: bool = True,
    checkpoint_path: Optional[str] = None,
):
    """Adversarial training: discriminator step first, then generator step.

    For every index along the second dimension of ``x[0]``:

    1. a fake sequence is built by calling the generator ``SEQUENCE_LENGTH``
       times and appending each predicted word to a sliding window that starts
       as the real input;
    2. the discriminator is trained on the real input (target 1) and on the
       detached fake sequence (target 0);
    3. the generator optimizer is stepped on the discriminator loss of the
       fake sequence against the "real" target.

    After the batches of each epoch, debug information about the last
    processed step is printed unconditionally and the generator is
    checkpointed.

    Args:
        generator: generator model; must expose ``global_epoch`` and return a
            dict with 'states', 'pred' and (when a target is given) 'loss' and
            'accuracy'.
        discriminator: discriminator model; must expose ``global_epoch`` and
            return a dict with 'logits', 'pred', 'loss' and 'accuracy'.
        gen_optimizer: optimizer updating the generator.
        dis_optimizer: optimizer updating the discriminator.
        train_dataloader: iterable of ``(x, y)`` batches, ``x`` being
            ``(sequences, genres)``.
        start_epoch: first epoch index (use it to resume training).
        epochs: exclusive upper bound of the epoch range.
        verbose: print the summary line every epoch (the last epoch is always
            printed).
        checkpoint_path: where the generator checkpoint is written; defaults to
            the module-level ``PATH`` at call time.

    Returns:
        Dict with 'gen_train_history' and 'dis_train_history', the mean
        generator / discriminator loss of each epoch.
    """
    if checkpoint_path is None:
        checkpoint_path = PATH

    gen_train_history = []
    dis_train_history = []

    # Iterate epochs starting from start_epoch
    for epoch in range(start_epoch, epochs):
        gen_losses = []
        dis_losses = []
        gen_accuracies = []
        dis_accuracies = []

        # Iterate batches of the training set
        for x, y in tqdm(train_dataloader):
            input_sequences = x[0]
            input_genres    = x[1]
            seq_numbers      = input_sequences.size()[1]
            input_batch_size = input_sequences.size()[0]

            # Real and fake targets for the discriminator (constant within a batch)
            target_real = torch.ones([input_batch_size], device=device)
            target_fake = torch.zeros([input_batch_size], device=device)

            # Init generator states for the LSTM
            gen_states = None

            # Iterate sequences and targets of the batch
            for seq_idx in range(seq_numbers):
                # Get the seq_idx-th sequence (column) and target (column)
                inputs = input_sequences[:, seq_idx]
                target = y[:, seq_idx]

                #=================================================
                # GENERATE A FAKE SEQUENCE

                fake = inputs
                fake_states = None

                # SEQUENCE_LENGTH words generation through moving window
                for step in range(SEQUENCE_LENGTH):
                    if step == 0:
                        gen_out = generator((inputs, input_genres), target, gen_states)

                        gen_states = gen_out['states']
                        gen_words  = gen_out['pred']

                        # Loss and accuracy computed by the generator itself, for comparison
                        gen_loss_     = gen_out['loss']
                        gen_accuracy_ = gen_out['accuracy']

                        fake_states = gen_states
                    else:
                        # NOTE(review): every step feeds the original `inputs`
                        # window rather than the updated `fake` window; only the
                        # LSTM states change between steps. Confirm this is intended.
                        fake_out = generator((inputs, input_genres), states=fake_states)

                        fake_states = fake_out['states']
                        gen_words   = fake_out['pred']

                    # Append the generated word to the window (removing the head)
                    fake = torch.cat((fake[:, 1:], gen_words), 1)
                #=================================================

                #===============================================================
                # TRAIN FIRST THE DISCRIMINATOR AND THEN THE GENERATOR

                # Train the DISCRIMINATOR on the true/generated data
                dis_optimizer.zero_grad()

                discriminator_out_true = discriminator((inputs, input_genres), target_real)
                discriminator_loss_true = discriminator_out_true['loss']
                discriminator_acc_true  = discriminator_out_true['accuracy']

                # Detach the generated data because we are focusing on the discriminator
                discriminator_out_fake = discriminator((fake.detach(), input_genres), target_fake)
                discriminator_loss_fake = discriminator_out_fake['loss']
                discriminator_acc_fake  = discriminator_out_fake['accuracy']

                dis_loss     = discriminator_loss_true + discriminator_loss_fake
                dis_accuracy = (discriminator_acc_true + discriminator_acc_fake) / 2

                dis_loss.backward()
                dis_optimizer.step()

                # Train the GENERATOR
                # We invert the labels here and don't train the discriminator because we
                # want the generator to make things the discriminator classifies as true.
                # NOTE(review): `fake` appears to hold predicted word indices (it is
                # decoded with get_word_from_index below), so this loss may not
                # propagate any gradient back into the generator - verify.
                gen_optimizer.zero_grad()

                generator_discriminator_out = discriminator((fake, input_genres), target_real)

                gen_loss     = generator_discriminator_out['loss']
                gen_accuracy = generator_discriminator_out['accuracy']

                # Detach the carried LSTM states from this step's history
                gen_states = repackage_states(gen_states)

                gen_loss.backward()
                gen_optimizer.step()
                #===============================================================

                # Track progress. Losses are detached so that the stored values
                # do not keep the autograd graph of every step alive.
                gen_losses.append(gen_loss.detach())
                dis_losses.append(dis_loss.detach())
                gen_accuracies.append(gen_accuracy)
                dis_accuracies.append(dis_accuracy)

        # Print input, target and generated words of the last processed step
        print("input words:\t%r" %convert(inputs))
        print("target word:\t%r" %([get_word_from_index(idx.item()) for idx in target]))
        print("word generated:\t%r" %([get_word_from_index(idx.item()) for idx in gen_out['pred']]))
        print("loss (using discriminator): %.3f" %gen_loss.item())
        print("accuracy (using discriminator): %.3f" %gen_accuracy.item())
        print("loss (using generator): %.3f" %gen_loss_.item())
        print("accuracy (using generator): %.3f" %gen_accuracy_.item())
        print("+"*40)

        # Print the discriminator view of the last processed step
        print("input words:\t%r" %convert(inputs))
        print("discriminator prediction (true):\t%r" %([x.item() for x in discriminator_out_true['logits']]))
        print("discriminator prediction (true) rounded:\t%r" %([int(x) for x in discriminator_out_true['pred']]))
        print("fake words:\t%r" %convert(fake))
        print("discriminator prediction (fake):\t%r" %([x.item() for x in discriminator_out_fake['logits']]))
        print("discriminator prediction (fake) rounded:\t%r" %([int(x) for x in discriminator_out_fake['pred']]))
        print("loss (sum of the losses): %.3f" %dis_loss.item())
        print("accuracy (sum of the accuracies): %.3f" %dis_accuracy.item())
        print("loss (true): %.3f" %discriminator_loss_true.item())
        print("loss (fake): %.3f" %discriminator_loss_fake.item())
        print("accuracy (true): %.3f" %discriminator_acc_true.item())
        print("accuracy (fake): %.3f" %discriminator_acc_fake.item())
        print("*"*40)

        # Use the models that were passed in, not the notebook-level globals
        generator.global_epoch += 1
        discriminator.global_epoch += 1

        gen_mean_loss = sum(gen_losses) / len(gen_losses)
        dis_mean_loss = sum(dis_losses) / len(dis_losses)
        gen_mean_accuracy = sum(gen_accuracies) / len(gen_accuracies)
        dis_mean_accuracy = sum(dis_accuracies) / len(dis_accuracies)
        gen_train_history.append(gen_mean_loss.item())
        dis_train_history.append(dis_mean_loss.item())
        if verbose or epoch == epochs - 1:
            print(f'  Epoch {generator.global_epoch:3d} => Generator loss: {gen_mean_loss:0.6f}, Discriminator loss: {dis_mean_loss:0.6f} : Generator accuracy: {gen_mean_accuracy:0.3f}, Discriminator accuracy: {dis_mean_accuracy:0.3f}')

        # Save generator model (checkpoint); the discriminator is not saved here.
        # NOTE: 'epoch' is the index of the epoch that has just been completed.
        torch.save({
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': gen_optimizer.state_dict(),
            'loss': gen_mean_loss,
            'accuracy': gen_mean_accuracy
            }, checkpoint_path)

    return {
        'gen_train_history': gen_train_history,
        'dis_train_history': dis_train_history
    }

In [None]:
# Train only the generator
# NOTE: the whole cell below is wrapped in a triple-quoted string, so it is
# currently disabled: evaluating it only yields the string itself and none of
# the restore / training statements inside it are executed.
'''restore = True
epoch  = 0
epochs = 20

if restore:
  checkpoint = torch.load(PATH)
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  loss     = checkpoint['loss']
  accuracy = checkpoint['accuracy']
  
  #epochs -= epoch
  print("Resume from epoch %d with loss %.3f and accuracy %.3f" %(epoch, loss, accuracy))

# Train the model
logs = train_and_evaluate_gen(
    generator=gen,
    discriminator=dis,
    gen_optimizer=gen_optimizer,
    dis_optimizer=dis_optimizer,
    train_dataloader=train_dataloader,
    start_epoch=epoch,
    epochs=epochs)#'''

'restore = True\nepoch  = 0\nepochs = 20\n\nif restore:\n  checkpoint = torch.load(PATH)\n  gen.load_state_dict(checkpoint[\'model_state_dict\'])\n  gen_optimizer.load_state_dict(checkpoint[\'optimizer_state_dict\'])\n  epoch = checkpoint[\'epoch\']\n  loss     = checkpoint[\'loss\']\n  accuracy = checkpoint[\'accuracy\']\n  \n  #epochs -= epoch\n  print("Resume from epoch %d with loss %.3f and accuracy %.3f" %(epoch, loss, accuracy))\n\n# Train the model\nlogs = train_and_evaluate_gen(\n    generator=gen,\n    discriminator=dis,\n    gen_optimizer=gen_optimizer,\n    dis_optimizer=dis_optimizer,\n    train_dataloader=train_dataloader,\n    start_epoch=epoch,\n    epochs=epochs)#'

In [None]:
# Train only the discriminator
# NOTE: the whole cell below is wrapped in a triple-quoted string, so it is
# currently disabled: evaluating it only yields the string itself.
# NOTE(review): the restore branch inside the string loads the checkpoint into
# `gen` / `gen_optimizer` rather than `dis` / `dis_optimizer`; revisit this
# before re-enabling the cell.
'''restore = False
epoch  = 0
epochs = 20

if restore:
  checkpoint = torch.load(PATH)
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  loss     = checkpoint['loss']
  accuracy = checkpoint['accuracy']
  
  #epochs -= epoch
  print("Resume from epoch %d with loss %.3f and accuracy %.3f" %(epoch, loss, accuracy))

logs = train_and_evaluate_dis(
    generator=gen,
    discriminator=dis,
    gen_optimizer=gen_optimizer,
    dis_optimizer=dis_optimizer,
    train_dataloader=train_dataloader,
    start_epoch=epoch,
    epochs=epochs)#'''

'restore = False\nepoch  = 0\nepochs = 20\n\nif restore:\n  checkpoint = torch.load(PATH)\n  gen.load_state_dict(checkpoint[\'model_state_dict\'])\n  gen_optimizer.load_state_dict(checkpoint[\'optimizer_state_dict\'])\n  epoch = checkpoint[\'epoch\']\n  loss     = checkpoint[\'loss\']\n  accuracy = checkpoint[\'accuracy\']\n  \n  #epochs -= epoch\n  print("Resume from epoch %d with loss %.3f and accuracy %.3f" %(epoch, loss, accuracy))\n\nlogs = train_and_evaluate_dis(\n    generator=gen,\n    discriminator=dis,\n    gen_optimizer=gen_optimizer,\n    dis_optimizer=dis_optimizer,\n    train_dataloader=train_dataloader,\n    start_epoch=epoch,\n    epochs=epochs)#'

In [None]:
# Train the GAN
restore = True   # resume the generator from the checkpoint stored at PATH
epoch  = 0       # first epoch index when not restoring
epochs = 20      # exclusive upper bound of the epoch range

if restore:
  # map_location lets a checkpoint saved on the GPU be restored on a CPU-only runtime
  checkpoint = torch.load(PATH, map_location=device)
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  # The checkpoint stores the index of the last COMPLETED epoch (it is written
  # at the end of each epoch), so training resumes from the following one
  # instead of repeating it.
  epoch = checkpoint['epoch'] + 1
  loss     = checkpoint['loss']
  accuracy = checkpoint['accuracy']

  print("Resume from epoch %d with loss %.3f and accuracy %.3f" %(epoch, loss, accuracy))

# Only the generator is restored above: the GAN checkpoint does not contain
# the discriminator, which therefore continues from its current in-memory weights.
logs = train_and_evaluate_GAN(
    generator=gen,
    discriminator=dis,
    gen_optimizer=gen_optimizer,
    dis_optimizer=dis_optimizer,
    train_dataloader=train_dataloader,
    start_epoch=epoch,
    epochs=epochs)

Resume from epoch 2 with loss 13.352 and accuracy 0.000


  0%|          | 0/31 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

Plot loss charts

In [None]:
def plot_logs(logs: Dict, data, title: str):
    """Draw ``data`` as a single 'Train loss' curve and show the figure.

    The x axis is the position of each value in ``data`` (labelled 'Epochs'),
    the y axis is labelled 'Loss' and ``title`` is used as the figure title.
    ``logs`` is accepted but not read by the body.
    """
    x_positions = list(range(len(data)))

    plt.figure(figsize=(8,6))
    plt.plot(x_positions, data, label='Train loss')

    # Decorations
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(loc="upper right")

    plt.show()

# Plot the per-epoch discriminator loss stored in `logs` by the last training call
plot_logs(logs, logs['dis_train_history'], 'Discriminator loss')

In [None]:
# Plot the per-epoch generator loss stored in `logs` by the last training call
plot_logs(logs, logs['gen_train_history'], 'Generator loss')

Perform a simple discriminator test

In [None]:
# Test the discriminator
def test(text, target):
  """Run the global discriminator ``dis`` over sliding windows of ``text``.

  ``text`` is split on whitespace and every window of ``SEQUENCE_LENGTH``
  consecutive words is converted to vocabulary indices. Windows containing a
  word for which ``get_index_from_word`` returns ``None`` are skipped. The
  discriminator prediction for each remaining window is compared with
  ``target`` and the number of matches is printed at the end.

  Args:
    text: plain text to evaluate.
    target: expected prediction for every window (1 = true, 0 = fake).
  """
  # Split text on whitespaces
  text_splitted = text.split()

  # Calculate the number of sequences in the text
  sequences_number = len(text_splitted) - SEQUENCE_LENGTH

  skipped_seq = 0
  score       = 0

  # Evaluate in inference mode (dropout disabled) and restore the previous
  # mode afterwards, so that a later training run is not affected.
  was_training = dis.training
  dis.eval()
  try:
    # Iterate sequences in the text
    for i in range(sequences_number):
      # Get the window of SEQUENCE_LENGTH words starting at the current word
      seq = text_splitted[i:i + SEQUENCE_LENGTH]

      # Convert sequence from words to indices
      converted_seq = []
      for word in seq:
        index = get_index_from_word(word)

        # If the word is not in the vocabulary, break the loop
        if index is None:
          break

        converted_seq.append(index)

      # A shorter converted sequence means an unknown word was found: skip it
      if len(converted_seq) < SEQUENCE_LENGTH:
        skipped_seq += 1
        continue

      print("seq: ", seq)
      print("converted_seq: ", converted_seq)

      # Convert to a batch of one sequence on the selected device (works on CPU too)
      tensor = torch.tensor(converted_seq, device=device).unsqueeze(0)

      # Predict the sequence -> 1 true, 0 fake.
      # NOTE(review): the training loops call the discriminator with a
      # (sequences, genres) tuple, while only the token tensor is passed here;
      # confirm that the model's forward accepts this input.
      with torch.no_grad():
        out = dis(tensor)
      pred = out["pred"].item()

      print(i, " - Discriminator pred: ", pred, " (expected ", target, ")")
      print("*"*40)

      if pred == target:
        score += 1
  finally:
    dis.train(was_training)

  print("\nDiscriminator score (all sequences): ", score, "/", sequences_number)
  print("Skipped sequences: ", skipped_seq)
  print("Discriminator score (without skipped sequences): ", score, "/", sequences_number - skipped_seq)


# Driver for test(): seven cases run in order. Cases 1-4 reuse the same real
# text and clean it cumulatively (each step starts from the text produced by
# the previous one), so the statement order matters.

# TEST 1: real prose, untouched (original casing, punctuation and curly quotes)
print("="*15 + " TEST 1 " + "="*15)
print("Use plain text 1")
text = "After school, Kamal took the girls to the old house. It was very old and very dirty too. There was rubbish everywhere. The windows were broken and the walls were damp. It was scary. Amy didn’t like it. There were paintings of zombies and skeletons on the walls. “We’re going to take photos for the school art competition,” said Kamal. Amy didn’t like it but she didn’t say anything. “Where’s Grant?” asked Tara. “Er, he’s buying more paint.” Kamal looked away quickly. Tara thought he looked suspicious. “It’s getting dark, can we go now?” said Amy. She didn’t like zombies."
# 1.0 = the discriminator is expected to classify the text as true
target = 1.0
print("text: ", text)
test(text, target)
print("="*40)

# TEST 2: same text, lowercased
print("\n" + "="*15 + " TEST 2 " + "="*15)
print("Use text.lower()")
text = text.lower()
print("text: ", text)
test(text, target)
print("="*40)

# TEST 3: additionally drop commas, full stops, curly double quotes and question marks
print("\n" + "="*15 + " TEST 3 " + "="*15)
print("Use text.lower() and remove punctuation")
# Mapping a code point to None makes str.translate delete that character
translations = { ord(','): None, ord('.'): None, ord('“'): None, ord('”'): None, ord('?'): None }
text = text.translate(translations)
print("text: ", text)
test(text, target)
print("="*40)

# TEST 4: additionally replace the curly apostrophe with the ASCII one
print("\n" + "="*15 + " TEST 4 " + "="*15)
print("Use text.lower(), remove punctuation and replace ’ with '")
text = text.replace("’", "'")
print("text: ", text)
test(text, target)
print("="*40)

# TEST 5: a short real opening followed by unrelated words, original casing
print("\n" + "="*15 + " TEST 5 " + "="*15)
print("Use plain text 2")
text = "Saturday was silent Surely it paperboard debtors lia astoundingly productimage gelding insects specialises dentistry habitus balloons yankee xenon dependant affiliated flab circadian racecourse extravaganza aerials improvementHigher symmetries mallory daubed cranial yarns postsMost spectacles rav mediates nastiest fractions else unused clothingoptional irl boundless maritime hats spycam scamGames elitists chickens homefront subsequent fortnight perfectI salesmen unconsolidated interiors stylised discharge confederation infringer couplers avenues hypermarkets intimidation dainty iguana randoms underskirt laughingstock bootstrapped ox resultsofonline alchemist oriental daze circulation vWords airflow clark airplane rol intense evidently marginata unconstructive busses hexafluoride remanded collegiate buttocks trad gaya areaAdvantagesCountryside tailed untraceable supervisor glaciation ulasanEnregistrer contemptible esta flipping datoer thermogenesis tubs vacating settings"
# 0.0 = the discriminator is expected to classify the text as fake
target = 0.0
print("text: ", text)
test(text, target)
print("="*40)

# TEST 6: text of TEST 5 lowercased (the `text` variable itself is left unchanged)
print("\n" + "="*15 + " TEST 6 " + "="*15)
print("Use text.lower()")
test(text.lower(), target)
print("="*40)

# TEST 7: a third text with the same opening words, evaluated lowercased with the same target
print("\n" + "="*15 + " TEST 7 " + "="*15)
print("Use text.lower() text 3")
text = "Saturday was silent Surely it frothy pacifies bannock conscripting cles flues ruminative anathematized sysinternals insensitivities inpact dline farmstay castigates endnote carpels feasted pioneer qualia heathcare katsuobushi étude suburi tersebut ideastream evinced crashers に vandalous rebaudioside disinvite automap jermaine sharepoint postlink pleeeeease locos supose limner tryal recents sada spilled gossipping redoing ritzy daerah minimart compl potted buis kwami repent peroxidation dramz seedheads uncapping reggie norethindrone particolarmente appeareth syndication luigi cen hematologist lampworked unshipped softs steroidogenesis haziness shahi erogenous zarzuela neuroscience panthenol ruing aseries distemper cscript contemptible queerness mommies quatrain rubbishy speck catbacks dorsey roadgoing tablesetting consid motor vesca unveiling repeatability functi anywhen daisychain afflicted convergence rver"
test(text.lower(), target)
print("="*40)

# Store data to files

Save the dictionaries that convert words to indices and vice versa

In [None]:
#import pickle
import json

FILENAME_W2I = '/content/drive/MyDrive/DM project - NLP lyrics generation/Dictionaries/words2indices'
FILENAME_I2W = '/content/drive/MyDrive/DM project - NLP lyrics generation/Dictionaries/indices2words'

# The dictionaries are stored as JSON so that the files stay human-readable
# (a binary pickle dump was the alternative and is not used).
def save_dictionary(filename, data):
  """Serialize *data* as JSON into the file ``filename + ".json"``.

  Args:
    filename: Destination path without extension; ".json" is appended.
    data: JSON-serializable object to store.
  """
  with open(filename + ".json", "w") as json_file:
    json_file.write(json.dumps(data))

# Sanity check: show the entries stored at indices 0, 1 and 2
print(indices2words[0])
print(indices2words[1])
print(indices2words[2])


# Write both lookup tables (save_dictionary appends ".json" to each path)
# NOTE(review): if indices2words is a dict keyed by int, JSON stores the keys
# as strings - confirm that the loading side converts them back.
save_dictionary(FILENAME_W2I, words2indices)
save_dictionary(FILENAME_I2W, indices2words)

Save the one-hot encoding dictionary for the genres

In [None]:
# Path prefix of the genre encoding file (save_dictionary appends ".json")
FILENAME = '/content/drive/MyDrive/DM project - NLP lyrics generation/Dictionaries/one_hot_encoding_genres'

save_dictionary(FILENAME, one_hot_encoding_genres)

Save word vectors (word2vec embedding)

In [None]:
# Destination file of the embedding vectors (this reuses and overwrites FILENAME)
FILENAME = '/content/drive/MyDrive/DM project - NLP lyrics generation/Dictionaries/word_vectors.pt'

# Serialize word_vectors with torch.save so that torch.load can read them back
torch.save(word_vectors, FILENAME)

Generate lyrics

In [None]:
# Generate text: settings and initial state of the generation loop
TEXT_LENGTH = 100                # Goal text length for a hard truncation (not enforced by the generation loop)
LINES = random.randrange(10, 50) # Number of generated lines after which the text is truncated (soft truncation)

# Seed word(s) and genre of the lyrics
word  = "The"
genre = "Pop"

# Running state of the generation
states = None
text = ""
prev_word = ""
lines           = 0
generated_words = 0

# Words whose capitalized form must be used in the output text
word2capitalize = ["I", "I'm", "I'd"]
# Tokens that are attached to the previous word without a space
punctuation_subset = { '.', ',', ';', ':', '!', '?', ')', ']', '}', '$', '/', '…', '...', '..' }

input_words = word.strip().split()

# Normalize the seed words for the model and start the output text with them
for i, w in enumerate(input_words):
  # A word missing from the vocabulary in its current form is replaced by the
  # lowercase version (as it must be present in one of the two forms)
  if w not in words2indices:
    input_words[i] = w.lower()

  if i > 0:
    text += ' ' + w
  else:
    # The very first word of the text gets a capital first letter
    w = w[0].upper() + w[1:]
    text = w

  prev_word = w

# One hot encode the genre and add the batch dimension
input_genre = torch.tensor(one_hot_encoding_genres[genre], device=device).unsqueeze(0)

print("Word:", word)
print("Genre:", genre)


def generate_next_word(input_words, states=None):
  """Predict the word that follows *input_words* with the global generator.

  Args:
    input_words: List of words forming the current input window.
    states: Recurrent states returned by the previous call, or None to start
      from scratch.

  Returns:
    A tuple ``(next_word, states)``: the predicted word and the generator
    states to pass to the next call.

  Raises:
    ValueError: If a word of *input_words* is not in the vocabulary.
  """
  # Convert words to indices
  indices = [get_index_from_word(w) for w in input_words]

  # get_index_from_word returns None for unknown words; fail with a clear
  # message instead of an obscure tensor construction error
  unknown = [w for w, index in zip(input_words, indices) if index is None]
  if unknown:
    raise ValueError("Words not in the vocabulary: %r" % unknown)

  # Add the batch dimension
  indices = torch.tensor(indices, device=device).unsqueeze(0)

  # Inference only: without no_grad every call would extend the autograd graph
  # through the recurrent states handed over from the previous call
  with torch.no_grad():
    y = gen((indices, input_genre), states=states)

  next_word_index = y['pred'].item()

  return get_word_from_index(next_word_index), y['states']


# Generate words until LINES lines have been produced.
# NOTE(review): TEXT_LENGTH is not enforced here, so the loop only ends once
# the generator has emitted LINES newline tokens - confirm that is intended.

# Lowercase form -> output form of the words that are always capitalized.
# Looking the word up here (instead of looping with a variable named `word`)
# keeps the seed `word` defined above from being overwritten.
capitalized_forms = {entry.lower(): entry for entry in word2capitalize}

while lines < LINES:
  # Generate next word
  next_word, states = generate_next_word(input_words, states)

  # Slide the input window: drop the head and append the generated word
  input_words = input_words[1:] + [next_word]

  # Replace the generated word with its capitalized version when it has one
  next_word = capitalized_forms.get(next_word, next_word)

  # A word that follows a newline or a dot starts with a capital letter
  # (slicing instead of indexing so that an empty token does not raise)
  if prev_word in ('\n', '.'):
    next_word = next_word[:1].upper() + next_word[1:]

  # No space is inserted after a newline or an open parenthesis, nor before a
  # newline or a punctuation token
  if prev_word in ('\n', '(') or next_word == '\n' or next_word in punctuation_subset:
    if next_word == '\n':
      # Update generated lines
      lines += 1

      # Stop before appending the final newline once the goal is reached
      if lines == LINES:
        break

    # Add the generated word to the output text without prepending a space
    text += next_word

  else:
    # Add the generated word to the output text prepending a space
    text += ' ' + next_word

  prev_word = next_word
  generated_words += 1


print("\nlines:", LINES)
print("generated words:", generated_words)
print("\nLyrics:")
print(text)

# WGAN

Discriminator is no longer a binary classifier but a critic

In [None]:
class Discriminator(nn.Module):
  """WGAN critic that scores a sequence of word indices for a given genre.

  The words are embedded with the pretrained vectors and passed through a
  3-layer LSTM. The genre vector is concatenated to every LSTM output and a
  linear layer maps each position to one unbounded score; the score of the
  last position is the result.
  """

  def __init__(
    self,
    word_vectors: torch.Tensor,
    lstm_hidden_size: int,
    dense_size: int
  ):
    """Build the layers.

    Args:
      word_vectors: Pretrained embedding matrix (one row per vocabulary word).
      lstm_hidden_size: Hidden size of the LSTM.
      dense_size: Input size of the dense layer; it must be equal to
        ``lstm_hidden_size`` plus the length of the genre vector.
    """
    super().__init__()

    # Embedding layer (from_pretrained keeps the vectors frozen by default)
    self.embedding = torch.nn.Embedding.from_pretrained(word_vectors)

    # Recurrent layer (LSTM)
    self.rnn = torch.nn.LSTM(input_size=word_vectors.size(1), hidden_size=lstm_hidden_size, num_layers=3, batch_first=True)

    # Dense layer producing one score per position
    self.dense = torch.nn.Linear(dense_size, 1)
    torch.nn.init.uniform_(self.dense.weight)

    # Dropout function (created here but not applied in forward)
    self.dropout = nn.Dropout(p=0.1)

    # Number of completed training epochs, maintained by the training loop
    self.global_epoch = 0

  def forward(self, x, y=None):
    """Score a batch of sequences.

    Args:
      x: Pair ``(lyrics, genres)``. ``lyrics`` holds word indices with shape
        (batch, sequence length); ``genres`` holds one genre vector per
        sequence with shape (batch, number of genres).
      y: Optional targets with shape (batch,). When given, 'loss' and
        'accuracy' are added to the result.

    Returns:
      Dict with 'logits' (raw scores, shape (batch, 1)), 'pred' (rounded
      scores) and, when *y* is given, 'loss' and 'accuracy'.
    """
    # Split input in lyrics and genre
    lyrics = x[0]
    genres = x[1]

    # Embedding words from indices
    out = self.embedding(lyrics)

    # Recurrent layer: keep the per-position outputs, drop the final states
    out = self.rnn(out)[0]

    # Repeat the genre vector of every sequence for each word of the sequence.
    # expand creates a view (no copy) and works for any sequence length and
    # batch size; device and dtype follow the LSTM output so that the
    # concatenation below is always valid.
    seq_length = lyrics.size(1)
    genres = genres.to(device=out.device, dtype=out.dtype)
    genres = genres.unsqueeze(1).expand(-1, seq_length, -1)

    # Concatenate the LSTM output with the encoding of genres
    out = torch.cat((out, genres), dim=-1)

    # Dense layer
    out = self.dense(out)

    # Use the last prediction
    logits = out[:, -1, :]

    # NOTE(review): the scores are unbounded, so rounding them only matches a
    # 0/1 target by coincidence - the accuracy is of limited use for a critic
    pred = torch.round(logits)

    result = {'logits': logits, 'pred': pred}

    # compute loss
    if y is not None:
      y = y.unsqueeze(-1)
      result['loss']     = self.loss(logits, y)
      result['accuracy'] = self.accuracy(pred, y)

    return result

  def loss(self, pred, target):
    """Return the mean score of the batch.

    *target* is ignored: the caller combines the means obtained on real and
    on fake batches into the critic loss.
    """
    return torch.mean(pred)

  def accuracy(self, pred, target):
    """Return the fraction of the batch whose rounded score equals *target*."""
    return torch.sum(pred == target) / pred.size()[0]

Instantiate the models and the optimizers

In [None]:
############ Parameters: #############

# The dense layers receive the LSTM output concatenated with the genre
# encoding, so their input size is derived from the LSTM size instead of
# repeating the literal value.
GEN_LSTM_SIZE   = 256
GEN_DENSE_SIZE  = GEN_LSTM_SIZE + NUMBER_GENRES
GEN_HIDDEN_SIZE = 2048  # currently not passed to the generator

DIS_LSTM_SIZE   = 256
DIS_DENSE_SIZE  = DIS_LSTM_SIZE + NUMBER_GENRES
DIS_HIDDEN_SIZE = 64    # currently not passed to the discriminator

# Learning rate for the optimizers (GAN training; 0.01 was used without GAN)
LR = 0.0002

# Betas hyperparam for Adam optimizers
BETAS = (0.5, 0.999)

######################################

gen = Generator(
    word_vectors,
    lstm_hidden_size=GEN_LSTM_SIZE,
    dense_size=GEN_DENSE_SIZE,
    vocab_size=len(word_vectors))

dis = Discriminator(
    word_vectors,
    lstm_hidden_size=DIS_LSTM_SIZE,
    dense_size=DIS_DENSE_SIZE)

# Move the models to the selected device BEFORE building the optimizers, as
# recommended by the torch.optim documentation (`device` is "cuda" when a GPU
# is available and "cpu" otherwise)
gen.to(device)
dis.to(device)

gen_optimizer = torch.optim.Adam(gen.parameters(), lr=LR, betas=BETAS)
dis_optimizer = torch.optim.Adam(dis.parameters(), lr=LR, betas=BETAS)

# Print the models summaries
print(gen)
print(dis)

Generator(
  (embedding): Embedding(17626, 300)
  (rnn): LSTM(300, 256, batch_first=True)
  (lin1): Linear(in_features=266, out_features=17626, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (loss): CrossEntropyLoss()
)
Discriminator(
  (embedding): Embedding(17626, 300)
  (rnn): LSTM(300, 256, num_layers=3, batch_first=True)
  (linear_one): Linear(in_features=266, out_features=1, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


WGAN training loop

In [None]:
from torch import autograd

# Checkpoint file of the generator, rewritten by the training loop at the end of every epoch
PATH = '/content/drive/MyDrive/DM project - NLP lyrics generation/generator_model_WGAN.pt'

# Weight of the gradient penalty term of the critic loss
# (the line of the training loop that adds the penalty is currently commented out)
GP_WEIGHT = 10.0

# Wrap hidden states in new Tensors, to detach them from their history
def repackage_states(states):
  """Return *states* with every tensor detached from its autograd history.

  A tensor is detached directly; any other value is treated as an iterable of
  states that is processed recursively and returned as a tuple.
  """
  if not isinstance(states, torch.Tensor):
    return tuple(map(repackage_states, states))
  return states.detach()

# Convert input sequences of indices in sequences of words
def convert(input):
  """Map every sequence of index tensors in *input* to a list of words.

  Returns a list with one list of words per sequence, in the same order.
  """
  return [
    [get_word_from_index(index.item()) for index in sequence]
    for sequence in input
  ]


def train_and_evaluate_GAN(
    generator:     nn.Module,
    discriminator: nn.Module,
    gen_optimizer: torch.optim.Optimizer,
    dis_optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    start_epoch: int = 0,
    epochs: int = 5,
    verbose: bool = True,
):
    """Train the generator against the critic with the WGAN objective.

    For every input sequence a fake sequence is produced by generating
    SEQUENCE_LENGTH words. The critic is then updated with the loss
    ``mean(score(fake)) - mean(score(real))`` and the generator with
    ``-mean(score(fake))``. At the end of every epoch the generator is saved
    to PATH and the values of the last processed sequence are printed.

    Args:
        generator: Model that generates the words.
        discriminator: Critic scoring real and generated sequences.
        gen_optimizer: Optimizer of the generator.
        dis_optimizer: Optimizer of the critic.
        train_dataloader: Yields ``(x, y)`` where ``x[0]`` holds the input
            sequences, ``x[1]`` the genres and ``y`` the target words.
        start_epoch: Index of the first epoch to run.
        epochs: Exclusive upper bound of the epoch range.
        verbose: Print the mean losses after every epoch (they are always
            printed after the last one).

    Returns:
        Dict with 'gen_train_history' and 'dis_train_history': the mean
        generator and critic loss of every epoch, as floats.
    """

    def gradient_penalty(true_data, fake_data, batch_size, input_genres):
        """Gradient penalty of the critic on data between real and fake.

        NOTE(review): experimental and currently unused (the call in the
        critic step below is commented out). The critic consumes integer word
        indices, which cannot require gradients, so this raises if enabled;
        the penalty has to be computed on the embedded vectors instead.
        """
        # One interpolation coefficient in [0, 1) per element
        alpha = torch.rand(batch_size, SEQUENCE_LENGTH, device=device)
        interpolated = true_data + alpha * (fake_data - true_data)

        # Cast to Long before giving it in input to the critic (indices are needed for the embedding)
        interpolated = interpolated.long()

        # Replace negative or out of vocabulary indices with PAD
        interpolated[interpolated < 0] = 0
        interpolated[interpolated >= len(word_vectors)] = 0

        # Create autograd variable to require gradients (not available for LongTensor)
        interpolated = autograd.Variable(interpolated, requires_grad=True)

        scores = discriminator((interpolated, input_genres))['logits']

        gradients = autograd.grad(
            outputs=scores, inputs=interpolated,
            grad_outputs=torch.ones_like(scores),
            create_graph=True, retain_graph=True, only_inputs=True)[0]

        # Penalize the distance of the gradient norm from 1
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    gen_train_history = []
    dis_train_history = []

    # Iterate epochs starting from start_epoch
    for epoch in range(start_epoch, epochs):
        gen_losses = []
        dis_losses = []
        gen_accuracies = []
        dis_accuracies = []

        # Iterate batches of the training set
        for x, y in tqdm(train_dataloader):
            input_sequences = x[0]
            input_genres    = x[1]

            # Batch size and number of sequences in each element of the batch
            input_batch_size = input_sequences.size()[0]
            seq_numbers      = input_sequences.size()[1]

            # Real and fake targets for the critic (same for the whole batch)
            target_real = torch.ones([input_batch_size], device=device)
            target_fake = torch.zeros([input_batch_size], device=device)

            # Init generator states for the LSTM
            gen_states = None

            # Iterate sequences and targets of the batch
            for i in range(seq_numbers):
                # Get the i-th sequence (column) and the i-th target (column)
                real_seq = input_sequences[:, i]
                target   = y[:, i]

                #=================================================
                # GENERATE A FAKE SEQUENCE

                fake = real_seq
                fake_states = None

                # SEQUENCE_LENGTH words generation through moving window
                # (the step index has its own name so that it does not
                # overwrite the sequence index `i`)
                for step in range(SEQUENCE_LENGTH):
                    if step == 0:
                        gen_out = generator((real_seq, input_genres), target, gen_states)

                        gen_states = gen_out['states']
                        gen_words  = gen_out['pred']

                        # Get loss and accuracy via generator for comparison
                        gen_loss_     = gen_out['loss']
                        gen_accuracy_ = gen_out['accuracy']

                        fake_states = gen_states
                    else:
                        # NOTE(review): every step feeds the original sequence
                        # and not the moving window `fake` - confirm this is
                        # intended
                        fake_out = generator((real_seq, input_genres), states=fake_states)

                        fake_states = fake_out['states']
                        gen_words   = fake_out['pred']

                    # Append the generated word to the sequence (removing the head)
                    fake = fake[:, 1:]
                    fake = torch.cat((fake, gen_words), 1)
                #=================================================


                #===============================================================
                # TRAIN FIRST THE DISCRIMINATOR AND THEN THE GENERATOR

                # Train the DISCRIMINATOR on the true/generated data
                dis_optimizer.zero_grad()

                discriminator_out_true = discriminator((real_seq, input_genres), target_real)
                discriminator_loss_true = discriminator_out_true['loss']
                discriminator_acc_true  = discriminator_out_true['accuracy']

                # Detach the gradients of generated data because we are focusing on the discriminator
                discriminator_out_fake = discriminator((fake.detach(), input_genres), target_fake)
                discriminator_loss_fake = discriminator_out_fake['loss']
                discriminator_acc_fake  = discriminator_out_fake['accuracy']

                dis_loss = discriminator_loss_fake - discriminator_loss_true
                dis_accuracy = (discriminator_acc_true + discriminator_acc_fake) / 2

                # Add gradient penalty
                #dis_loss += gradient_penalty(real_seq, fake.detach(), input_batch_size, input_genres) * GP_WEIGHT

                dis_loss.backward()
                dis_optimizer.step()


                # Train the GENERATOR
                # The generator is rewarded when the critic gives a high
                # score to the generated data.
                # NOTE(review): `fake` is concatenated with integer word
                # indices, so presumably no gradient reaches the generator
                # through the critic - confirm.

                # Zero the gradients on each iteration
                gen_optimizer.zero_grad()

                generator_discriminator_out = discriminator((fake, input_genres), target_real)

                gen_loss     = -generator_discriminator_out['loss']
                gen_accuracy = generator_discriminator_out['accuracy']

                gen_states = repackage_states(gen_states)

                gen_loss.backward()
                gen_optimizer.step()
                #===============================================================


                # Track progress. The losses are detached so that the lists do
                # not keep the autograd graph of every step alive until the
                # end of the epoch.
                gen_losses.append(gen_loss.detach())
                dis_losses.append(dis_loss.detach())
                gen_accuracies.append(gen_accuracy)
                dis_accuracies.append(dis_accuracy)

        # Use the models received as arguments (not the notebook globals)
        generator.global_epoch += 1
        discriminator.global_epoch += 1

        gen_mean_loss = sum(gen_losses) / len(gen_losses)
        dis_mean_loss = sum(dis_losses) / len(dis_losses)
        gen_mean_accuracy = sum(gen_accuracies) / len(gen_accuracies)
        dis_mean_accuracy = sum(dis_accuracies) / len(dis_accuracies)
        gen_train_history.append(gen_mean_loss.item())
        dis_train_history.append(dis_mean_loss.item())
        if verbose or epoch == epochs - 1:
            print(f'  Epoch {generator.global_epoch:3d} => Generator loss: {gen_mean_loss:0.6f}, Discriminator loss: {dis_mean_loss:0.6f} : Generator accuracy: {gen_mean_accuracy:0.3f}, Discriminator accuracy: {dis_mean_accuracy:0.3f}')


        # Save generator model to restore it in case of disconnection (checkpoint).
        # Loss and accuracy are stored as plain floats so that reading them
        # back does not depend on the device they were computed on.
        torch.save({
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': gen_optimizer.state_dict(),
            'loss': gen_mean_loss.item(),
            'accuracy': gen_mean_accuracy.item()
            }, PATH)

        # PRINT AT THE END OF A EPOCH (values of the last processed sequence)
        # Print input, target and generated words
        print("input words:\t%r" %convert(real_seq))
        print("target word:\t%r" %([get_word_from_index(idx.item()) for idx in target]))
        print("word generated:\t%r" %([get_word_from_index(idx.item()) for idx in gen_out['pred']]))
        print("loss (using discriminator): %.3f" %gen_loss.item())
        print("accuracy (using discriminator): %.3f" %gen_accuracy.item())
        print("loss (using generator): %.3f" %gen_loss_.item())
        print("accuracy (using generator): %.3f" %gen_accuracy_.item())
        print("+"*40)

        # Print input and generated words
        print("input words:\t%r" %convert(real_seq))
        print("discriminator prediction (true):\t%r" %([x.item() for x in discriminator_out_true['logits']]))
        print("discriminator prediction (true) rounded:\t%r" %([int(x) for x in discriminator_out_true['pred']]))
        print("fake words:\t%r" %convert(fake))
        print("discriminator prediction (fake):\t%r" %([x.item() for x in discriminator_out_fake['logits']]))
        print("discriminator prediction (fake) rounded:\t%r" %([int(x) for x in discriminator_out_fake['pred']]))
        print("loss (sum of the losses): %.3f" %dis_loss.item())
        print("accuracy (sum of the accuracies): %.3f" %dis_accuracy.item())
        print("loss (true): %.3f" %discriminator_loss_true.item())
        print("loss (fake): %.3f" %discriminator_loss_fake.item())
        print("accuracy (true): %.3f" %discriminator_acc_true.item())
        print("accuracy (fake): %.3f" %discriminator_acc_fake.item())
        print("*"*40)

    return {
        'gen_train_history': gen_train_history,
        'dis_train_history': dis_train_history
    }

In [None]:
# Train the GAN (set restore to True to resume the generator from the checkpoint at PATH)
restore = False
epoch  = 0
epochs = 20

if restore:
  # map_location remaps tensors stored in the checkpoint onto the device
  # selected for this session, so a checkpoint written on a GPU runtime can
  # still be restored when only the CPU is available.
  checkpoint = torch.load(PATH, map_location=device)

  # Only the generator and its optimizer are restored here; the discriminator
  # and its optimizer keep whatever state they currently have.
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  loss     = checkpoint['loss']
  accuracy = checkpoint['accuracy']

  #epochs -= epoch
  print("Resume from epoch %d with loss %.3f and accuracy %.3f" %(epoch, loss, accuracy))

# `epochs` is the exclusive upper bound of the epoch range, not the number of
# additional epochs to run after a restore.
logs = train_and_evaluate_GAN(
    generator=gen,
    discriminator=dis,
    gen_optimizer=gen_optimizer,
    dis_optimizer=dis_optimizer,
    train_dataloader=train_dataloader,
    start_epoch=epoch,
    epochs=epochs)

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   1 => Generator loss: 151.148727, Discriminator loss: -308.927460 : Generator accuracy: 0.007, Discriminator accuracy: 0.007
input words:	[['we', 'crawl', '\n', 'get', 'me', 'to'], ['nature', 'cheating', 'us', '\n', 'and', 'everything'], ['you', 'know', 'it', 'happens', 'every', 'day'], ['eyes', 'tightly', '\n', 'it', 'will', 'rain'], ['mobbing', '\n', 'deep', 'in', 'queens', 'with'], ['cane', '\n', 'gave', 'black', 'people', 'money'], ['will', 'become', 'lines', '\n', 'cut', "'em"], ['\n', 'many', 'people', ',', 'they', 'pay'], ["don't", 'wanna', 'fight', '\n', "don't", 'want'], ['i', "don't", 'like', 'your', 'life', 'and'], ['and', 'lady', 'liberty', 'moved', 'to', 'san'], ['you', 'need', '\n', 'i', 'come', 'before'], ['dear', '\n', 'my', 'anger', 'is', 'raging'], ['in', 'all', 'black', '\n', 'with', 'a'], ['i', 'have', 'tried', 'everything', 'lord', ','], ['\n', 'put', "'em", 'tings', 'in', 'the'], ['filled', 'with', 'darkness', '\n', 'so', 'if'], ['\n', 'baby', ',', 'i', '

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   2 => Generator loss: 226.440979, Discriminator loss: -454.915863 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['i', 'wish', 'i', 'could', 'spend', 'more'], ['legend', 'seeps', '\n', 'repercussions', 'of', 'my'], ['door', "you'd", 'hear', 'miss', 'jenny', 'crying'], ['will', 'roll', 'the', 'truck', 'out', 'of'], ['?', '\n', 'forgive', 'them', ',', 'my'], ['trust', 'nobody', '\n', 'greetings', 'comrades', '\n'], ['the', 'same', 'time', '\n', 'oh', 'what'], ['and', 'all', 'i', 'want', '\n', 'is'], ['and', 'shook', 'his', 'head', 'and', 'said'], ['are', 'no', 'games', '\n', 'to', 'only'], ['to', 'them', 'and', 'hope', 'they', 'tend'], ['net', 'fly', '\n', 'and', 'after', 'the'], ['absolutely', 'quiet', '\n', 'look', 'up', 'here'], ['you', 'swear', 'you', 'only', 'kissed', 'him'], ['like', 'to', 'go', 'home', 'and', 'go'], ['finally', 'made', 'it', 'across', 'the', 'great'], ['die', '\n', 'i', 'lived', 'good', '\n'], ['and', 'i', 'fell', 'every', 'night

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   3 => Generator loss: 289.434631, Discriminator loss: -581.399414 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['\n', 'see', 'me', 'flying', 'down', 'the'], ['shit', 'for', 'your', 'ass', '(', 'hmm'], ['feels', 'the', 'worst', '\n', 'the', 'bad'], ['echoes', 'in', 'my', 'head', '\n', 'and'], ['park', 'benches', '\n', 'eating', 'little', 'packets'], ['i', 'have', 'buried', 'deep', 'inside', '\n'], ['come', ',', 'the', 'time', 'to', 'run'], [',', 'and', "i'm", 'from', 'lebanon', ','], ['to', 'keep', 'cool', '\n', "it's", 'cliché'], ['into', 'the', 'rest', 'of', 'my', 'life'], ['meanest', 'woman', 'i', 'have', 'most', 'ever'], ['me', '\n', 'with', 'you', 'looking', 'so'], ['electrocute', '\n', 'you', 'will', 'never', 'be'], ['in', 'cause', '\n', 'this', "one's", 'for'], ['like', 'a', 'crime', '\n', 'cause', 'i'], ['to', 'the', 'murderers', 'we', 'have', 'loved'], ['with', 'new', 'city', 'blues', '?', '\n'], ['flipping', 'words', 'like', 'birds', '\n', 

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   4 => Generator loss: 351.501343, Discriminator loss: -707.126587 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['dem', 'no', 'friend', '\n', 'dem', 'no'], [',', 'you', 'will', '\n', 'that', 'you'], ['word', 'from', 'the', 'front', 'never', 'came'], ["it's", 'a', 'blessing', ',', 'for', 'bone'], ['be', 'no', 'wars', 'around', '\n', 'well'], ['than', 'number', 'one', '?', '\n', 'you'], ['\n', 'picking', 'me', 'a', 'bouquet', 'of'], ['i', 'forgot', 'how', 'to', 'kiss', 'and'], ['low', 'but', 'you', 'hear', 'the', 'bass'], ['want', 'to', 'capture', 'one', 'at', 'a'], ['abu', 'dhabi', ',', 'she', 'could', 'be'], ['you', 'have', 'got', 'those', 'things', '\n'], ['it', '\n', 'i', 'know', 'it', 'has'], ['till', 'the', 'judgment', 'day', '\n', 'lord'], ['like', 'to', 'go', 'home', 'and', 'go'], ['in', 'the', 'world', 'she', 'meet', '\n'], ['i', 'see', 'al', 'million', 'lights', '\n'], ['like', 'a', 'crime', '\n', 'cause', 'i'], ['my', 'intimacy', '\n', 'but'

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   5 => Generator loss: 413.315552, Discriminator loss: -832.660950 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['\n', "satan's", 'finest', '\n', 'well', ','], ['want', 'to', 'capture', 'one', 'at', 'a'], ['\n', 'go', 'on', ',', 'go', 'on'], ['say', '\n', 'chorus', ':', '\n', 'see'], ['it', 'lifts', 'me', 'up', ',', 'it'], ['i', 'seem', 'so', 'stressed', ',', 'i'], ['\n', 'our', 'fates', 'compounding', '\n', 'our'], ['\n', 'some', 'people', 'tell', 'you', 'what'], ['like', 'a', 'crime', '\n', 'cause', 'i'], ['\n', 'your', 'heart', 'is', 'glowing', '\n'], ['are', 'the', 'omega', '\n', 'showering', 'the'], ['will', 'never', 'find', 'us', '\n', 'listening'], ['thinking', 'on', 'our', 'sunny', 'days', '\n'], ['in', 'the', 'air', 'like', "it's", 'good'], ['but', 'i', "don't", 'worry', 'much', 'about'], ['still', 'got', 'wet', '\n', 'but', 'that'], ['she', 'just', 'makes', 'me', 'love', 'her'], ['but', 'i', "won't", 'sleep', "'till", 'i'], ['to', 'say', 'g

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   6 => Generator loss: 474.971313, Discriminator loss: -957.269531 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['\n', 'and', 'broke', 'it', 'apart', '\n'], ['prosecco', '\n', 'painting', 'pictures', 'like', "i'm"], ['that', 'cannot', 'be', 'filled', '\n', 'god'], ['a', 'lightning', 'crash', '\n', "i'm", 'gonna'], [',', 'shake', 'summing', '\n', 'you', 'know'], ['hug', 'the', 'law', '\n', 'arms', 'like'], ["i'm", 'a', 'refugee', ',', 'and', 'i'], ['hand', '\n', 'she', 'had', 'to', 'face'], ['of', 'the', "wind's", 'soft', 'lullaby', '\n'], ['\n', 'da', 'whine', 'deh', 'a', 'one'], ['\n', 'song', 'of', 'masses', ',', 'passing'], ['?', '\n', 'were', 'any', 'more', 'complicit'], ['!', ')', '\n', 'one', 'kind', 'for'], [',', 'two', ',', 'three', ',', 'four'], [',', "it's", 'not', 'dark', 'up', 'here'], ['this', '\n', 'something', 'about', 'you', ','], ['gods', 'were', 'young', 'none', 'ruled', 'under'], ['with', 'my', 'hands', 'in', 'the', 'air'], ['curta

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   7 => Generator loss: 536.523804, Discriminator loss: -1081.925659 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['but', 'i', 'want', 'to', 'give', 'it'], ["taxi's", 'waiting', ',', "he's", 'blowing', 'his'], ['\n', 'too', 'young', 'or', 'too', 'dumb'], [',', "haven't", 'found', 'him', 'yet', '\n'], ['oh', ',', "haven't", 'you', 'seen', 'they'], ['colors', ',', 'too', '\n', '\n', 'i'], ["i'm", 'a', 'refugee', ',', 'and', 'i'], ['\n', 'abdul', 'jabbar', ',', 'the', 'coolest'], ['you', 'are', 'ever', 'sad', 'i', 'will'], ['\n', 'black', 'or', 'white', '\n', 'day'], ['sheet', 'music', 'score', '\n', '\n', 'come'], ['can', 'never', 'move', 'backwards', '\n', 'hit'], ['jealousy', '\n', 'woah', ',', 'i', 'just'], ['tell', 'her', 'the', 'truth', 'before', 'she'], ['the', 'mercy', 'that', 'you', 'love', 'to'], ['tanny', 'bitch', '\n', 'who', 'wish', 'to'], ['away', ',', 'son', ',', 'the', 'worst'], ["i'm", 'old', '\n', 'my', 'bones', "don't"], ['i', 'just', 

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   8 => Generator loss: 597.935608, Discriminator loss: -1206.139771 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['hits', 'like', 'blues', 'clues', '\n', 'we'], ['carpet', 'and', 'a', 'jackie', 'chan', 'spear'], ['my', 'arms', 'tonight', '\n', 'noel', ','], ['!', '!', '\n', 'burnin', 'you', '!'], ['sweet', 'old', 'time', 'rock', '&', 'roll'], ['sadness', 'to', 'love', '\n', 'knowing', 'one'], ['\n', 'but', 'then', "i'm", 'the', 'only'], ['never', 'see', 'you', 'again', '\n', 'we'], ['of', 'an', 'evil', 'you', 'hold', '\n'], ['eyes', 'meet', '?', '\n', 'when', 'can'], ['make', 'up', 'new', 'excuses', '\n', 'faking'], ['the', 'weatherman', 'for', 'all', 'of', 'your'], ['you', 'are', 'on', 'solid', 'ground', '\n'], ['beside', 'you', ',', 'lost', 'in', 'the'], ['york', '.', 'escape', 'from', 'la', '.'], ['all', 'roads', 'no', 'longer', 'lead', 'to'], ['you', 'know', 'a', 'place', 'we', 'can'], ['\n', 'where', 'did', 'you', 'come', 'from'], ['in', '\n', '

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch   9 => Generator loss: 659.687927, Discriminator loss: -1331.133911 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['it', 'all', 'day', ',', 'a', 'weed'], ['spirit', 'he', 'wants', 'to', 'kill', '\n'], ['carries', 'up', 'and', 'down', 'the', 'halls'], ['world', 'green', '\n', 'hoping', 'soon', 'to'], ['were', 'unforgettable', '\n', 'you', 'were', 'everything'], ['the', 'way', 'she', 'moves', '\n', 'alright'], ['direction', '.', 'food', 'lies', 'on', 'the'], ['many', 'times', '\n', 'and', 'the', 'thought'], ['will', 'become', 'lines', '\n', 'cut', "'em"], ['the', 'bank', 'open', '\n', '\n', 'yeah'], ['i', 'see', 'al', 'million', 'lights', '\n'], ['flirting', "it's", 'the', 'clothing', '\n', 'taking'], ['\n', '\n', "let's", 'take', 'it', 'easy'], ['kid', 'up', 'in', 'the', 'class', 'made'], ['we', 'crawl', '\n', 'get', 'me', 'to'], ['oh', 'oh', ',', 'oh', 'oh', ','], ['on', 'd', 'ooh', 'ooh', 'ooh', '\n'], ['as', 'much', 'as', 'i', 'love', 'her'], ['alon

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  10 => Generator loss: 721.364685, Discriminator loss: -1455.306885 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['savoury', 'melody', ',', "that's", 'what', 'i'], ['\n', 'well', 'packaged', 'to', 'distribute', 'to'], ['it', 'way', 'too', 'much', '?', '\n'], ['gonna', 'get', 'me', 'some', '\n', '\n'], ['we', 'was', 'over', 'at', 'the', 'church'], ['why', ',', 'but', 'all', 'the', 'real'], ['how', 'can', 'you', 'win', 'some', '?'], ['in', 'my', 'backdoor', 'someday', '\n', 'said'], ['\n', 'song', 'of', 'masses', ',', 'passing'], ['wasting', 'no', 'more', 'time', '\n', 'though'], ['\n', 'where', 'my', 'heart', 'becomes', 'free'], ['me', 'dash', "'way", 'mi', 'pride', '\n'], ['believe', 'a', 'word', ',', 'they', 'will'], ['tried', '\n', 'i', 'have', 'seen', 'a'], ['park', '\n', 'all', 'alone', ',', 'early'], ['i', 'hang', 'onto', 'your', 'pictures', '\n'], ['the', 'slang', '\n', 'we', 'gang', 'bang'], ['\n', 'because', "i'm", 'never', 'gonna', 'let'], [

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  11 => Generator loss: 782.946716, Discriminator loss: -1580.089233 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['floating', 'through', 'my', 'veins', '\n', 'i'], ['they', "don't", 'love', 'you', 'like', 'i'], ['like', 'a', 'truth', '\n', '\n', 'i'], ['jesus', 'forgive', 'me', '\n', 'get', 'your'], ['\n', "don't", 'know', 'if', 'you', 'are'], ['fishing', '\n', 'and', 'catch', 'the', 'sunset'], ['are', '\n', 'buying', 'their', 'cocain', '\n'], ['own', '\n', 'is', 'gone', '\n', 'they'], ['who', 'walk', 'among', 'us', '\n', 'your'], ['look', 'around', 'what', 'happens', ',', 'the'], ['town', 'town', ',', "i'm", 'coming', 'down'], ['in', 'the', 'eye', 'perfume', '\n', 'i'], ['was', 'get', 'in', 'my', 'way', '\n'], ['you', 'could', 'wrap', 'me', 'in', 'a'], ['without', 'sleep', ',', 'weeks', 'and', "won't"], ['find', 'the', 'words', 'to', 'say', '\n'], ['if', 'i', 'should', 'love', 'again', '\n'], ["i'm", 'just', 'trying', 'to', 'make', 'you'], ['our', '

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  12 => Generator loss: 844.559937, Discriminator loss: -1704.513672 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['\n', 'woo', 'woo', '!', '\n', 'keeps'], ['jesus', 'forgive', 'me', '\n', 'get', 'your'], ['home', ',', 'back', 'on', 'your', 'own'], ['them', 'mostly', 'all', 'the', 'time', '\n'], ['timeless', 'classic', '\n', 'every', 'time', 'we'], ['palace', 'has', 'a', 'moss', 'problem', '\n'], ['\n', '\n', 'ancestors', 'so', 'grotesque', '\n'], ['are', 'two', 'trains', 'on', 'different', 'tracks'], ['\n', 'well', 'packaged', 'to', 'distribute', 'to'], ['the', 'hands', 'that', 'hold', 'you', '\n'], ['there', 'will', 'be', 'cars', 'that', "won't"], ['guessing', 'why', 'i', 'never', 'could', '\n'], ['people', 'ten', 'feet', 'tall', '\n', 'on'], ["don't", 'know', 'if', 'i', 'do', '\n'], ['to', 'go', '\n', "it's", 'because', 'i'], ['i', 'pray', '?', 'should', 'i', 'die'], ['every', 'once', 'in', 'a', 'while', '\n'], ['\n', 'you', "can't", 'hang', 'with'

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  13 => Generator loss: 906.030823, Discriminator loss: -1828.866455 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['i', 'want', 'you', '\n', 'dirty', 'love'], ['\n', 'where', 'did', 'you', 'get', 'that'], ['snow', 'is', 'cold', ',', 'rain', 'is'], ['face', 'on', 'the', 'water', '\n', 'wrinkles'], ['it', 'would', 'float', 'until', 'it', 'reached'], ['mythology', '\n', 'we', 'are', 'called', 'to'], ["'bout", 'to', 'go', 'on', 'live', '\n'], ['\n', 'wear', 'a', 'smile', '\n', 'dig'], ["i'm", 'breathing', 'in', '\n', "i'm", 'breathing'], ['so', 'cause', 'i', 'still', '\n', 'have'], ['the', "end's", 'in', 'sight', '\n', 'we'], ['to', 'repeat', '\n', 'carbon', 'faction', 'brings'], ['what', 'to', 'do', '\n', 'give', 'my'], ['skin', 'and', 'different', 'religion', '\n', 'i'], ['of', 'our', 'nature', ',', 'ancient', 'future'], ['outside', '\n', 'burning', 'on', 'the', 'inside'], ['it', 'all', '?', '\n', 'what', 'to'], [',', 'go', 'a', 'foreign', 'pon', 'tour'

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  14 => Generator loss: 967.745361, Discriminator loss: -1953.200195 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['tanny', 'bitch', '\n', 'who', 'wish', 'to'], ['on', 'me', '\n', 'hey', 'lady', "won't"], ['\n', '\n', 'fuck', 'me', 'up', '\n'], ['\n', 'where', 'did', 'you', 'get', 'that'], ['spirit', 'he', 'wants', 'to', 'kill', '\n'], ['eyes', 'tightly', '\n', 'it', 'will', 'rain'], ['to', 'ancient', 'trees', ',', 'shall', 'flower'], ['you', 'have', 'got', 'those', 'things', '\n'], ['a', "moment's", 'pleasure', '?', '\n', 'can'], ['up', ',', 'get', 'down', '\n', 'get'], ['for', 'a', 'happy', 'ass', 'bitch', '\n'], ['go', '\n', 'i', 'gotta', 'figure', 'this'], ['and', "i'm", 'afraid', 'and', 'alone', 'with'], ['\n', 'the', 'bonniest', 'lad', 'in', 'all'], ['call', 'you', 'up', ',', 'invest', 'a'], ['dancing', 'and', "i'm", 'lazy', 'when', 'i'], ['to', 'say', '\n', 'there', 'is', 'a'], ['up', '\n', 'you', 'never', 'give', 'up'], ['but', 'i', "don't", '

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  15 => Generator loss: 1029.214233, Discriminator loss: -2077.953369 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['on', 'a', 'facade', '\n', "that's", 'what'], ['make', 'up', 'new', 'excuses', '\n', 'faking'], ['\n', 'and', 'a', 'girl', 'i', 'used'], ['is', 'fuck', 'all', 'to', 'do', '\n'], ['but', 'when', 'we', 'were', 'down', 'to'], ['can', 'choose', '\n', '\n', 'soon', 'as'], ['\n', 'i', 'will', 'never', 'see', 'when'], ['ever', 'get', 'to', 'leave', 'it', 'behind'], ['\n', 'when', 'you', 'are', 'out', 'there'], ['it', '\n', 'eazy', 'e', ':', "let's"], ['me', '\n', 'with', 'you', 'looking', 'so'], ['?', ')', '\n', "i'm", 'finna', 'pull'], ['\n', '\n', 'i', 'will', 'stop', 'and'], ['you', 'have', 'got', 'those', 'things', '\n'], ['red', ',', 'red', ',', 'red', ','], [',', 'defenceless', '\n', 'all', 'the', 'flowers'], ['royaume', 'des', 'vivants', 'et', 'des', 'morts'], ['i', 'mean', '\n', 'this', ',', 'watts'], ['we', 'were', 'drawn', 'away', ','

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  16 => Generator loss: 1090.707397, Discriminator loss: -2202.740234 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['i', 'try', 'to', 'pull', 'away', '\n'], ['\n', 'there', 'is', 'a', 'shadow', 'hanging'], ['the', 'tears', '\n', 'that', "i'm", 'cry'], ['the', 'rush', '\n', 'i', 'got', 'the'], ['doing', 'your', 'do', '\n', "that's", 'what'], ['hold', 'my', 'broken', 'heart', '\n', 'god'], ['truth', '\n', 'and', 'catch', 'a', 'glimpse'], ['there', 'and', 'then', 'were', 'gone', '\n'], ['you', 'are', 'living', 'in', 'the', 'past'], ['place', 'you', 'have', 'never', 'been', '\n'], ['me', 'honey', '\n', 'take', 'me', ','], ['are', '\n', 'buying', 'their', 'cocain', '\n'], [',', 'you', 'take', 'two', 'steps', ','], ['in', 'my', 'ears', '\n', 'my', 'heart'], ['york', '.', 'escape', 'from', 'la', '.'], ['she', 'just', 'makes', 'me', 'love', 'her'], ['\n', 'your', "protector's", 'coming', 'home', ','], ['state', '\n', 'he', 'was', 'born', 'with'], ['park', '\n

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  17 => Generator loss: 1151.964844, Discriminator loss: -2326.682617 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['\n', 'can', 'you', 'navigate', '\n', 'no'], ['to', 'suck', 'it', 'up', 'and', 'do'], ['\n', 'except', 'for', 'one', 'little', 'thing'], ['as', 'here', 'once', 'again', 'they', 'strut'], ['me', 'what', 'you', 'see', '?', '\n'], ['than', 'number', 'one', '?', '\n', 'you'], ['up', 'in', 'the', 'crib', 'with', 'us'], ['get', 'it', 'any', 'time', 'so', 'yes'], ['run', 'forever', '\n', 'you', 'are', 'dying'], ['you', 'have', 'been', 'picking', 'up', 'the'], ['and', "i'm", 'afraid', 'and', 'alone', 'with'], ['i', 'put', 'out', 'one', 'and', 'light'], ['hand', 'through', 'the', 'grey', 'doorway', 'at'], ["ain't", 'no', 'use', '\n', 'they', 'pile'], ['crash', 'when', 'i', 'come', 'through', '\n'], ['\n', 'da', 'whine', 'deh', 'a', 'one'], ['they', 'have', 'been', 'barking', 'up', 'the'], ['but', 'i', '\n', 'cannot', 'give', 'you'], [',', 'like',

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  18 => Generator loss: 1213.765137, Discriminator loss: -2451.107910 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[["i'm", 'not', 'that', 'young', 'anymore', '\n'], ['girls', '\n', "i'd", 'kick', 'it', 'with'], ['so', 'much', '\n', '\n', 'asleep', 'in'], ['dark', 'for', 'too', 'long', '\n', 'been'], ['of', 'us', 'and', 'turns', 'us', 'around'], ['skin', 'and', 'different', 'religion', '\n', 'i'], ['she', 'kissed', 'me', '\n', 'for', 'playing'], ['who', 'could', 'take', 'my', 'place', '\n'], ['on', 'christmas', 'morning', '\n', 'i', 'will'], ['the', 'scars', '\n', '\n', 'i', 'bet'], ['heard', 'but', 'by', 'night', 'in', 'our'], ['shore', '\n', 'and', 'then', 'and', 'there'], ['you', 'know', 'it', 'happens', 'every', 'day'], ['be', 'my', 'woman', ',', 'gal', ','], ['\n', 'that', 'old', 'swing', '\n', 'that'], ['near', 'future', '\n', 'please', 'make', 'sure'], ['\n', 'urging', 'to', 'free', 'my', 'kind'], ['field', '\n', 'may', 'this', 'be', 'the'], ['\

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  19 => Generator loss: 1275.011353, Discriminator loss: -2575.728027 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[['time', '\n', 'i', 'want', 'you', 'to'], ['\n', 'and', 'i', "haven't", 'been', 'the'], ['were', 'my', 'boat', 'in', 'the', 'deep'], ['we', 'alive', 'in', 'this', 'distant', 'dream'], ['wanna', 'go', '\n', 'go', 'without', 'you'], ['love', '\n', 'but', 'my', 'mind', 'holds'], ['the', 'slang', '\n', 'we', 'gang', 'bang'], ['\n', 'and', 'a', 'girl', 'i', 'used'], ['meanest', 'woman', 'i', 'have', 'most', 'ever'], ['toy', 'gun', 'sounds', '\n', 'grown', 'man'], ['in', 'the', 'world', 'she', 'meet', '\n'], ['?', '\n', 'alms', 'to', 'the', 'mother'], ['guessing', 'why', 'i', 'never', 'could', '\n'], ['sorry', ',', 'ya', '\n', 'slew', 'dem'], ['no', 'light', '\n', '\n', 'here', 'on'], ['i', 'were', 'broken', 'would', 'you', 'fix'], ['roll', '\n', 'i', 'confess', 'i', 'was'], ['all', '\n', 'just', 'let', 'it', 'out'], ['chance', 'and', 'changed'

  0%|          | 0/31 [00:00<?, ?it/s]

  Epoch  20 => Generator loss: 1336.964111, Discriminator loss: -2700.313721 : Generator accuracy: 0.000, Discriminator accuracy: 0.000
input words:	[[',', 'what', 'where', 'you', 'thinking', '?'], ['trying', 'to', 'do', '\n', 'some', 'days'], ['with', 'dreams', '\n', 'i', 'want', 'to'], ['and', 'smiling', ',', 'decent', ',', 'not'], ['woman', 'scream', 'in', 'your', 'face', '\n'], ['\n', 'and', 'if', 'you', 'wanna', 'do'], ['they', "ain't", 'coming', 'to', 'the', 'hood'], ['you', 'gave', 'me', 'shelter', 'in', 'your'], ['the', 'crown', '(', 'aah', ')', '\n'], ['like', 'others', 'of', 'my', 'station', '\n'], ['i', 'never', 'peeped', 'to', 'where', 'she'], ['\n', 'and', 'broke', 'it', 'apart', '\n'], ['strained', 'now', '\n', "'til", 'we', 'get'], ['you', '..', '\n', 'you', 'quietly', 'made'], ['\n', 'and', 'you', 'think', 'you', 'have'], ['for', 'sure', '\n', 'we', "can't", 'stand'], ['\n', 'endlessly', 'manifold', ',', 'self', 'contained'], ['and', 'while', 'we', 'sleep', '\n', 'we'],