# Install the required packages

In [1]:
!pip install wikipedia
!pip install nltk
!pip install pyspellchecker
!pip install networkx
!pip install dgl



In [2]:
!apt update
!apt install enchant --fix-missing
!apt install -qq enchant
!pip install pyenchant

Reading package lists... Done
E: Could not open lock file /var/lib/apt/lists/lock - open (13: Permission denied)
E: Unable to lock directory /var/lib/apt/lists/
W: Problem unlinking the file /var/cache/apt/pkgcache.bin - RemoveCaches (13: Permission denied)
W: Problem unlinking the file /var/cache/apt/srcpkgcache.bin - RemoveCaches (13: Permission denied)
E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?
E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?


# Import the required libraries

In [1]:
import os
import wikipedia
import nltk
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from spellchecker import SpellChecker
import enchant
import dgl
import networkx as nx
import matplotlib.pyplot as plt
import itertools
import pickle
import torch
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.dataloading import GraphDataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Set the embedding size and GPU device

In [2]:
EMBEDDING_DIMENSION = 64
device = torch.device('cuda')
torch.cuda.is_available()

True

In [3]:
nltk.download('punkt')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to /home/admin1/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/admin1/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

# Create a custom Wikipedia corpus

In [53]:
# The variable "keyword" decides what type of articles to pull from wikipedia
# We tested our approach on the following keywords: sports, celebrity, music

keyword = "music"
corpus_file = "data/" + keyword + "_corpus.pk"
corpus = ""

if os.path.isfile(corpus_file):
    corpus = pickle.load(open(corpus_file, "rb"))
else:

    search_results = wikipedia.search(keyword, results=1000)

    for result in search_results:
        try:
            summary = wikipedia.summary(result)
            corpus += summary
        except:
            print("Some Error occured for keyword : ", result)
        finally:
            pass
    pickle.dump(corpus, open(corpus_file, "wb"))

Some Error occured for keyword :  Music




  lis = BeautifulSoup(html).find_all('li')


Some Error occured for keyword :  Music (disambiguation)
Some Error occured for keyword :  Pop music
Some Error occured for keyword :  Sony Music
Some Error occured for keyword :  K-pop
Some Error occured for keyword :  Musical instrument
Some Error occured for keyword :  Composer
Some Error occured for keyword :  Record producer
Some Error occured for keyword :  Garage
Some Error occured for keyword :  20th-century music
Some Error occured for keyword :  Pop rock
Some Error occured for keyword :  Rebel Music
Some Error occured for keyword :  Jazz
Some Error occured for keyword :  Music group (disambiguation)
Some Error occured for keyword :  American Music Awards of 2002
Some Error occured for keyword :  Mode (music)
Some Error occured for keyword :  Roots music
Some Error occured for keyword :  English music
Some Error occured for keyword :  Musical notation
Some Error occured for keyword :  Music player
Some Error occured for keyword :  Techno
Some Error occured for keyword :  Roxy 

# Pre-process the text corpus

In [7]:
def remove_whitespace(text):
    return  " ".join(text.split())
 
def remove_stopwords(text_tokens):
    result = []
    en_stopwords = stopwords.words('english')
    
    for token in text_tokens:
        if token not in en_stopwords:
            result.append(token)

    return result

def remove_punctuations(text):
    tokenizer = RegexpTokenizer(r"\w+")
    lst=tokenizer.tokenize(' '.join(text))
    return lst

In [8]:
def pre_process_text(text_corpus):
    #Step 1: Convert every word to lower case to avoid any ambiguity
    lower_text_corpus = text_corpus.lower()
    
    #Step 2: Remove extra whitespaces
    lower_text_corpus = remove_whitespace(lower_text_corpus)
    
    #Step 3: Tokenize in word level
    tokens = word_tokenize(lower_text_corpus)
    
    #Step 4: Remove the stopwords
#     tokens = remove_stopwords(tokens)
    
    #Step 5: Remove punctuations
    tokens = remove_punctuations(tokens)
    return " ".join(tokens)
    

In [56]:
corpus[:1000]

'"Music! Music! Music! (Put Another Nickel In)" is a popular song written by Stephen Weiss and Bernie Baum and published in 1950."Music! Music! Music! (Put Another Nickel In)" is a popular song written by Stephen Weiss and Bernie Baum and published in 1950.Country (also called country and western) is a genre of popular music that originated with blues, church music such as Southern gospel and spirituals, old-time, and American folk music forms including Appalachian, Cajun, Creole, Hawaiian, and the cowboy Western music styles of New Mexico, Red Dirt, Tejano, and Texas country. Its popularized roots originate in the Southern and Southwestern United States of the early 1920s.\nCountry music often consists of ballads and honky-tonk dance tunes with generally simple form, folk lyrics, and harmonies often accompanied by string instruments such as electric and acoustic guitars, steel guitars (such as pedal steels and dobros), banjos, and fiddles as well as harmonicas. Blues modes have been u

In [57]:
processed_corpus = pre_process_text(corpus)

In [58]:
processed_corpus[:1000]

'music music music put another nickel in is a popular song written by stephen weiss and bernie baum and published in 1950 music music music put another nickel in is a popular song written by stephen weiss and bernie baum and published in 1950 country also called country and western is a genre of popular music that originated with blues church music such as southern gospel and spirituals old time and american folk music forms including appalachian cajun creole hawaiian and the cowboy western music styles of new mexico red dirt tejano and texas country its popularized roots originate in the southern and southwestern united states of the early 1920s country music often consists of ballads and honky tonk dance tunes with generally simple form folk lyrics and harmonies often accompanied by string instruments such as electric and acoustic guitars steel guitars such as pedal steels and dobros banjos and fiddles as well as harmonicas blues modes have been used extensively throughout its record

In [59]:
len(processed_corpus.split(" "))

76600

# Construct lookup for word-to-index and index-to-word

In [60]:
# word_count_map = dict()
word_index = dict()
index_word = dict()
processed_word_tokens = processed_corpus.split(" ")
unique_words = list(set(processed_word_tokens))

for idx, word in enumerate(unique_words):
    word_index[word] = idx
    index_word[idx] = word

In [61]:
len(word_index)

9344

# Construct a single large graph for the pre-processed corpus

In [62]:
src = []
dst = []

for idx, word in enumerate(processed_word_tokens):
    if idx < (len(processed_word_tokens)-1):
        src.append(word_index[word])
        dst.append(word_index[processed_word_tokens[idx+1]])

In [63]:
dgl_graph = dgl.DGLGraph()
dgl_graph.add_edges(src, dst)

In [64]:
print('We have %d nodes.' % dgl_graph.number_of_nodes())
print('We have %d edges.' % dgl_graph.number_of_edges())

We have 9344 nodes.
We have 76599 edges.


# Initialize the vector embeddings for each node in the graph

In [65]:
embedding_layer = torch.nn.Embedding(len(unique_words), EMBEDDING_DIMENSION)

In [66]:
input_emb = torch.LongTensor([range(0,len(unique_words),1)])
node_embeddings = None
with torch.no_grad():
    node_embeddings = embedding_layer(input_emb).squeeze()

In [67]:
node_embeddings.shape

torch.Size([9344, 64])

# Define Graph dataset as a DGLDataset

In [68]:
class WikipediaCorpusGraphDataset(dgl.data.DGLDataset):
    def __init__(self, src, dst, features):
        self.src = src
        self.dst = dst
        self.features = features
        self.graph = None
        super().__init__(name="wikipedia_corpus")
    
    def process(self):
        self.graph = dgl.graph((self.src, self.dst))
        self.graph.ndata['feat'] = self.features
        
    def edges(self):
        return self.src, self.dst
        

# Train Node Embeddings as Edge Prediction on Graph Dataset

In [69]:
graph_dataset = WikipediaCorpusGraphDataset(src, dst, node_embeddings)

In [70]:
graph = graph_dataset.graph
graph.number_of_edges()

76599

In [71]:
u, v = graph.edges()

eids = np.arange(graph.number_of_edges())
eids = np.random.permutation(eids)

test_size = int(len(eids) * 0.5)
train_size = graph.number_of_edges() - test_size

test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

In [72]:
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(graph.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)

neg_eids = np.random.choice(len(neg_u), graph.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]

In [73]:
train_g = dgl.remove_edges(graph, eids[:test_size])

In [74]:
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=graph.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=graph.number_of_nodes())

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=graph.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=graph.number_of_nodes())

# Define the GCN model

In [75]:
class GCNStack(nn.Module):
    def __init__(self, in_feats, h_feats, num_layers=2, dropout=0.25):
        super(GCNStack, self).__init__()
        self.convs = nn.ModuleList()
        self.convs.append(self.build_conv_model(in_feats, h_feats))
        
        for i in range(num_layers-1):
            self.convs.append(self.build_conv_model(h_feats, h_feats))
            
        self.num_layers = len(self.convs)
        self.lns = nn.ModuleList()
        for i in range(self.num_layers-1):
            self.lns.append(nn.LayerNorm(h_feats))
        
        self.dropout = dropout
           
    def build_conv_model(self, input_dim, hidden_dim):
        return GraphConv(input_dim, hidden_dim, allow_zero_in_degree=True)
    
    def forward(self, graph, input_features):
        x = input_features
        embeddings = None
        
        for i, conv_layer in enumerate(self.convs):
            x = conv_layer(graph, x)
            x = F.relu(x)
            embeddings = x
            x = F.dropout(x, p=self.dropout, training = self.training)
            
            if not i == self.num_layers-1:
                x = self.lns[i](x)
            
        return embeddings
    
    def loss(self, pred, label):
        return F.nll_loss(pred, label)
        

In [76]:
criterion = torch.nn.BCELoss()

In [77]:
def compute_loss(pos_score, neg_score):
#     print("Compute LOSS :: Positive Scores :: ", pos_score)
#     print("Compute LOSS :: Negative Scores :: ", neg_score)
    scores = torch.cat([pos_score.float(), neg_score.float()])
    pos_labels = []
    neg_labels = []
    for i in range(pos_score.shape[0]):
#         pos_labels.append([1,0])
        pos_labels.append([1])
    for i in range(neg_score.shape[0]):
#         neg_labels.append([0,1])
        neg_labels.append([0])
    labels = torch.cat([torch.Tensor(pos_labels), torch.Tensor(neg_labels)]).to(device)
#     print("Scores = ", scores)
#     print("Labels = ", labels)
    
#     return F.binary_cross_entropy_with_logits(scores, labels)
    return criterion(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score.float(), neg_score.float()])
    pos_labels = []
    neg_labels = []
    for i in range(pos_score.shape[0]):
#         pos_labels.append([1,0])
        pos_labels.append([1])
    for i in range(neg_score.shape[0]):
#         neg_labels.append([0,1])
        neg_labels.append([0])
    labels = torch.cat([torch.Tensor(pos_labels), torch.Tensor(neg_labels)]).to(device)
    return roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy())

def compute_ap(pos_score, neg_score):
    scores = torch.cat([pos_score.float(), neg_score.float()])
    pos_labels = []
    neg_labels = []
    for i in range(pos_score.shape[0]):
#         pos_labels.append([1,0])
        pos_labels.append([1])
    for i in range(neg_score.shape[0]):
#         neg_labels.append([0,1])
        neg_labels.append([0])
    labels = torch.cat([torch.Tensor(pos_labels), torch.Tensor(neg_labels)]).to(device)
    return average_precision_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy())

In [78]:
def train_edge_classification(nc_model,linear_pred_edges, dgl_dataset, epochs=100):
    optimizer = torch.optim.Adam(nc_model.parameters(), lr=0.01)
    optimizer_linear_pred = torch.optim.Adam(linear_pred_edges[0].parameters(), lr=0.01)
    
    nc_model = nc_model.to(device)
    linear_pred_edges = linear_pred_edges.to(device)
    
    nc_model.train()
    linear_pred_edges.train()
    
    graph = dgl_dataset.graph.to(device)
    
    best_val_acc = 0
    best_test_acc = 0

    features = graph.ndata['feat']
    node_embeddings = None
    edges_src, edges_dst = graph.edges()
    train_pos_edges_src, train_pos_edges_dst = train_pos_g.edges()
    train_neg_edges_src, train_neg_edges_dst = train_neg_g.edges()
    
    test_pos_edges_src, test_pos_edges_dst = test_pos_g.edges()
    test_neg_edges_src, test_neg_edges_dst = test_neg_g.edges()
    
    for e in tqdm(range(epochs)):
        num_correct = 0
        num_tests = 0
        
        node_embeddings = nc_model(graph, features)

        edge_embeddings_pos = (node_embeddings[train_pos_edges_src] * node_embeddings[train_pos_edges_dst])/2
        edge_embeddings_neg = (node_embeddings[train_neg_edges_src] * node_embeddings[train_neg_edges_dst])/2
        
        test_edge_embeddings_pos = (node_embeddings[test_pos_edges_src] * node_embeddings[test_pos_edges_dst])/2
        test_edge_embeddings_neg = (node_embeddings[test_neg_edges_src] * node_embeddings[test_neg_edges_dst])/2
    
        pred_edges_pos = linear_pred_edges(edge_embeddings_pos).to(device)
        pred_edges_neg = linear_pred_edges(edge_embeddings_neg).to(device)

        loss = compute_loss(pred_edges_pos, pred_edges_neg)
        
        optimizer.zero_grad()
        optimizer_linear_pred.zero_grad()

        loss.backward()
        optimizer.step()
        optimizer_linear_pred.step()
        
        if e % 10 == 0:
            print('In epoch {}, loss: {}'.format(e, loss))
            print('AUC Score -> ', compute_auc(pred_edges_pos, pred_edges_neg))
            print('AP Score  -> ', compute_ap(pred_edges_pos, pred_edges_neg))

    with torch.no_grad():
        pos_score = linear_pred_edges(test_edge_embeddings_pos).to(device)
        neg_score = linear_pred_edges(test_edge_embeddings_neg).to(device)
        print('Test AUC = ', compute_auc(pos_score, neg_score))
        print('Test AP  = ', compute_ap(pos_score, neg_score))
        
    return node_embeddings

# Train Edge Classification / Link Prediction

In [79]:
num_feats = graph_dataset.graph.ndata['feat'].shape[1]
nc_model = GCNStack(num_feats, EMBEDDING_DIMENSION, num_layers=2)

In [80]:
linear_pred_edges = torch.nn.Sequential(torch.nn.Linear(EMBEDDING_DIMENSION,1), torch.nn.Sigmoid())
word_embeddings = train_edge_classification(nc_model,linear_pred_edges, graph_dataset, epochs=200)

  0%|                                                   | 0/200 [00:00<?, ?it/s]

In epoch 0, loss: 0.6822222471237183
AUC Score ->  0.6762261103422889
AP Score  ->  0.7369128088844534


  4%|█▉                                         | 9/200 [00:01<00:19,  9.56it/s]

In epoch 10, loss: 0.49734437465667725
AUC Score ->  0.9245177555917621


  6%|██▋                                       | 13/200 [00:02<00:21,  8.81it/s]

AP Score  ->  0.9372830845092832


 10%|████▍                                     | 21/200 [00:02<00:21,  8.38it/s]

In epoch 20, loss: 0.45951351523399353
AUC Score ->  0.9418889961074108
AP Score  ->  0.9516554808079793


 15%|██████▎                                   | 30/200 [00:03<00:14, 11.60it/s]

In epoch 30, loss: 0.4162270426750183
AUC Score ->  0.9393495170053651


 16%|██████▋                                   | 32/200 [00:04<00:20,  8.16it/s]

AP Score  ->  0.9543930216964814


 20%|████████▍                                 | 40/200 [00:04<00:15, 10.39it/s]

In epoch 40, loss: 0.3654099404811859
AUC Score ->  0.9350390117868416
AP Score  ->  0.9538003565584503


 24%|██████████▎                               | 49/200 [00:05<00:14, 10.51it/s]

In epoch 50, loss: 0.31927812099456787
AUC Score ->  0.9405832219866521


 26%|███████████▏                              | 53/200 [00:06<00:17,  8.51it/s]

AP Score  ->  0.9573640289001606


 30%|████████████▌                             | 60/200 [00:06<00:12, 11.50it/s]

In epoch 60, loss: 0.28653910756111145
AUC Score ->  0.9526941464595164
AP Score  ->  0.9637897036149945


 35%|██████████████▋                           | 70/200 [00:07<00:13,  9.74it/s]

In epoch 70, loss: 0.26090943813323975
AUC Score ->  0.9608710690644833
AP Score  ->  0.9680528349321953


 40%|█████████████████                         | 81/200 [00:09<00:12,  9.49it/s]

In epoch 80, loss: 0.24568793177604675
AUC Score ->  0.9649054066767107
AP Score  ->  0.9704014610602435


 46%|███████████████████                       | 91/200 [00:10<00:11,  9.28it/s]

In epoch 90, loss: 0.2395978569984436
AUC Score ->  0.9667778275808001
AP Score  ->  0.9713089253074805


 50%|████████████████████▋                    | 101/200 [00:11<00:11,  8.73it/s]

In epoch 100, loss: 0.23226088285446167
AUC Score ->  0.9683986750881115
AP Score  ->  0.9725336610838561


 56%|██████████████████████▊                  | 111/200 [00:12<00:10,  8.51it/s]

In epoch 110, loss: 0.23098817467689514
AUC Score ->  0.9690164637430209
AP Score  ->  0.9729315005810908


 60%|████████████████████████▍                | 119/200 [00:12<00:07, 11.47it/s]

In epoch 120, loss: 0.2242017239332199
AUC Score ->  0.9707251252650164
AP Score  ->  0.9742070015898018


 66%|██████████████████████████▊              | 131/200 [00:14<00:08,  8.48it/s]

In epoch 130, loss: 0.22203992307186127
AUC Score ->  0.9712476501305483
AP Score  ->  0.9746051943353633


 70%|████████████████████████████▉            | 141/200 [00:15<00:07,  7.82it/s]

In epoch 140, loss: 0.21795400977134705
AUC Score ->  0.9719706494692855
AP Score  ->  0.9751982052691142


 76%|██████████████████████████████▉          | 151/200 [00:16<00:05,  8.81it/s]

In epoch 150, loss: 0.21605145931243896
AUC Score ->  0.9723762333917336
AP Score  ->  0.9754768248381445


 80%|████████████████████████████████▊        | 160/200 [00:17<00:04,  9.75it/s]

In epoch 160, loss: 0.21587936580181122
AUC Score ->  0.9728457597365856
AP Score  ->  0.9758021572851515


 85%|██████████████████████████████████▊      | 170/200 [00:18<00:02, 11.43it/s]

In epoch 170, loss: 0.21081021428108215
AUC Score ->  0.9735773865115993
AP Score  ->  0.97644273189574


 90%|████████████████████████████████████▉    | 180/200 [00:19<00:01, 10.83it/s]

In epoch 180, loss: 0.21041707694530487
AUC Score ->  0.9737635606623537
AP Score  ->  0.9766473246730882


 95%|██████████████████████████████████████▉  | 190/200 [00:20<00:00, 10.26it/s]

In epoch 190, loss: 0.20679619908332825
AUC Score ->  0.9745665574787477


 96%|███████████████████████████████████████▎ | 192/200 [00:20<00:01,  7.56it/s]

AP Score  ->  0.9772354546070432


100%|█████████████████████████████████████████| 200/200 [00:21<00:00,  9.32it/s]


Test AUC =  0.9645492760261468
Test AP  =  0.9699718507665834


In [81]:
word_embeddings.shape

torch.Size([9344, 64])

In [82]:
word_embeddings[0]

tensor([0.7975, 0.0000, 0.4177, 0.6892, 1.1205, 0.0000, 0.9303, 1.1190, 2.3185,
        0.0000, 0.8899, 0.0000, 0.0000, 0.0000, 0.1875, 0.0000, 0.6177, 0.0000,
        0.0000, 0.7564, 0.9341, 0.0000, 1.9227, 0.0000, 1.3086, 1.3907, 0.0000,
        0.0000, 1.3917, 0.3986, 0.0000, 0.0000, 0.6023, 0.9117, 0.0000, 0.9524,
        0.6193, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9112, 0.6968, 1.1294,
        0.4323, 0.0000, 0.0000, 0.8983, 0.9505, 0.0000, 0.4560, 0.1150, 0.7290,
        0.1269, 0.2383, 0.0000, 0.0942, 0.0000, 1.6153, 0.8349, 0.0000, 0.1989,
        0.9258], device='cuda:0', grad_fn=<SelectBackward0>)

# Use the Word Embeddings from GCN for LSTM

### 1. Create n-gram sequences from the original corpus

In [83]:
one_gram_sequences = []
two_gram_sequences = []
three_gram_sequences = []
five_gram_sequences = []
ten_gram_sequences = []
processed_corpus_tokens = processed_corpus.split(" ")

corpus_with_indices = [word_index[w] for w in processed_corpus_tokens]

for idx, word in enumerate(corpus_with_indices):
    if idx+11 < len(processed_corpus_tokens):
        one_gram_sequences.append(processed_corpus_tokens[idx:idx+2])
        two_gram_sequences.append(processed_corpus_tokens[idx:idx+3])
        three_gram_sequences.append(processed_corpus_tokens[idx:idx+4])
        five_gram_sequences.append(processed_corpus_tokens[idx:idx+6])
        ten_gram_sequences.append(processed_corpus_tokens[idx:idx+11])

In [84]:
zero_array = np.zeros((EMBEDDING_DIMENSION))
# zero_array

In [85]:
zero_array = np.zeros((EMBEDDING_DIMENSION))
def word_seq_to_padded_embeddings(word_seq, word_index, max_seq_length=3):
    index_seq = []
    for seq in word_seq:
        embedding_list = []
        additional_zero_pads = max_seq_length - len(seq)
        for word in seq:
            idx = word_index[word]
            embedding_list.append(word_embeddings[idx].cpu().detach().numpy())
        for i in range(additional_zero_pads):
            embedding_list.append(zero_array)
        index_seq.append(np.array(embedding_list))
    return index_seq
    

In [86]:
def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    return np.eye(num_classes, dtype='uint8')[y]

In [87]:
from torch.nn.utils.rnn import pad_sequence

In [88]:
X = []
Y = []
X.extend(word_seq_to_padded_embeddings(np.array(one_gram_sequences)[:,:-1], word_index, 10))
y_one_gram_indices = [word_index[word] for word in np.array(one_gram_sequences)[:,-1]]
y_one_gram_onehot = to_categorical(y_one_gram_indices, len(unique_words))
Y.extend(y_one_gram_onehot)

X.extend(word_seq_to_padded_embeddings(np.array(two_gram_sequences)[:,:-1], word_index, 10))
y_two_gram_indices = [word_index[word] for word in np.array(two_gram_sequences)[:,-1]]
y_two_gram_onehot = to_categorical(y_two_gram_indices, len(unique_words))
Y.extend(y_two_gram_onehot)

X.extend(word_seq_to_padded_embeddings(np.array(three_gram_sequences)[:,:-1], word_index, 10))
y_three_gram_indices = [word_index[word] for word in np.array(three_gram_sequences)[:,-1]]
y_three_gram_onehot = to_categorical(y_three_gram_indices, len(unique_words))
Y.extend(y_three_gram_onehot)

X.extend(word_seq_to_padded_embeddings(np.array(five_gram_sequences)[:,:-1], word_index, 10))
y_five_gram_indices = [word_index[word] for word in np.array(five_gram_sequences)[:,-1]]
y_five_gram_onehot = to_categorical(y_five_gram_indices, len(unique_words))
Y.extend(y_five_gram_onehot)

X.extend(word_seq_to_padded_embeddings(np.array(ten_gram_sequences)[:,:-1], word_index, 10))
y_ten_gram_indices = [word_index[word] for word in np.array(ten_gram_sequences)[:,-1]]
y_ten_gram_onehot = to_categorical(y_ten_gram_indices, len(unique_words))
Y.extend(y_ten_gram_onehot)

In [89]:
print(len(X))
print(len(Y))

382945
382945


### 2. Create 80/20 split for training and testing

In [90]:
X_train, X_test, Y_train, Y_test = train_test_split( X, Y, test_size=0.2, random_state=42)
X_val, X_test, Y_val, Y_test = train_test_split(X_test, Y_test, test_size=0.1, random_state=40)

In [91]:
from torch.utils.data import TensorDataset, DataLoader

cross_entropy = torch.nn.CrossEntropyLoss()
train_data = TensorDataset(torch.from_numpy(np.array(X_train)), torch.from_numpy(np.array(Y_train)))
val_data = TensorDataset(torch.from_numpy(np.array(X_val)), torch.from_numpy(np.array(Y_val)))
test_data = TensorDataset(torch.from_numpy(np.array(X_test)), torch.from_numpy(np.array(Y_test)))

batch_size = 100

train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size,drop_last=True)
val_loader = DataLoader(val_data, shuffle=True, batch_size=batch_size,drop_last=True)
test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size,drop_last=True)

# Define the Next word prediction LSTM model

In [92]:
class NextWordPredModel(torch.nn.Module):
    def __init__(self, embedding_dimension, hidden_units, num_classes):
        super().__init__()
        self.input_size = embedding_dimension
        self.hidden_units = hidden_units
        self.num_classes = num_classes
        self.n_layers=2
        
        self.rnn = torch.nn.LSTM(input_size = self.input_size, hidden_size=self.hidden_units,
                                 num_layers=self.n_layers, dropout=0.2, batch_first=True)
        self.fc = torch.nn.Linear(self.hidden_units, self.num_classes)
        
    def forward(self, x, state = None):
        h, rnn_state = self.rnn(x, state)
        h = self.fc(h)
        scores = F.log_softmax(h, dim=1)
        return scores[:,-1,:], rnn_state
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (weight.new(self.n_layers, batch_size, self.hidden_units).zero_().to(device),
                      weight.new(self.n_layers, batch_size, self.hidden_units).zero_().to(device))
        return hidden

In [93]:
def train_model(model, epochs, lr=0.005):
    cross_entropy = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    model.train()
    for ep in tqdm(range(epochs)):
        init_state = model.init_hidden(batch_size)
        train_losses=[]
        num_correct = 0
        total = 0
        h = tuple([e.data for e in init_state])
        for idx, (inp_seq, op_word) in enumerate(train_loader):
            
            inp_seq, op_word = inp_seq.float().to(device), op_word.float().to(device)
            optimizer.zero_grad()

            prediction, (h_state, c_state) = model(inp_seq, h)
            h = (h_state.detach(), c_state.detach())
            
            loss = cross_entropy(prediction, op_word)
            num_correct += prediction.argmax(dim=1).eq(op_word.argmax(dim=1)).sum().item()
            total += op_word.shape[0]
            train_losses.append(loss.item())
    #         print("Loss = ", loss.item())
            loss.backward()
            optimizer.step()
        if (ep+1)%10 == 0:
            val_h = model.init_hidden(batch_size)
            val_losses = []
            val_num_correct = 0
            val_total = 0
            model.eval()
            val_h = tuple([each.data for each in val_h])
            
            for inp, lab in val_loader:
                
                inp, lab = inp.to(device), lab.to(device)
                out, (val_h_state, val_c_state) = model(inp.float(), val_h)
                val_h = (val_h_state.detach(), val_c_state.detach())
                
                val_num_correct += out.argmax(dim=1).eq(lab.argmax(dim=1)).sum().item()
                val_total += lab.shape[0]
                val_loss = cross_entropy(out, lab.float())
                val_losses.append(val_loss.item())
                
            model.train()
            if (num_correct >0):
                print("At least one correct :: num_correct = ", num_correct)
            train_acc = (num_correct / total)
            val_acc = (val_num_correct / val_total)
#             print(train_losses)
            print("Epoch: {}/{}...".format(ep+1, epochs),
                  "Step: {}...".format(idx),
                  "Train Loss: {:.6f}...".format(np.mean(train_losses)),
                  "Train Acc: {:.3f}...".format(train_acc),
                  "Val Loss: {:.6f}".format(np.mean(val_losses)),
                  "Val Acc: {:.3f}...".format(val_acc))

# Train the LSTM Model

In [427]:
# Training block for celebrity corpus
# Train Accuracy = 61.3% Val Accuracy = 33.2%
# model = NextWordPredModel(embedding_dimension=EMBEDDING_DIMENSION, hidden_units=256, num_classes=len(unique_words))
# train_model(model.to(device), 500, 0.0005)

  2%|▊                                       | 10/500 [02:55<2:26:53, 17.99s/it]

At least one correct :: num_correct =  33687
Epoch: 10/500... Step: 2298... Train Loss: 4.827836... Train Acc: 0.147... Val Loss: 5.064383 Val Acc: 0.148...


  4%|█▌                                      | 20/500 [05:54<2:26:23, 18.30s/it]

At least one correct :: num_correct =  54679
Epoch: 20/500... Step: 2298... Train Loss: 3.775284... Train Acc: 0.238... Val Loss: 4.421593 Val Acc: 0.202...


  6%|██▍                                     | 30/500 [08:54<2:24:02, 18.39s/it]

At least one correct :: num_correct =  71063
Epoch: 30/500... Step: 2298... Train Loss: 3.219056... Train Acc: 0.309... Val Loss: 4.164583 Val Acc: 0.232...


  8%|███▏                                    | 40/500 [11:55<2:21:33, 18.47s/it]

At least one correct :: num_correct =  82759
Epoch: 40/500... Step: 2298... Train Loss: 2.880222... Train Acc: 0.360... Val Loss: 4.062173 Val Acc: 0.248...


 10%|████                                    | 50/500 [14:56<2:18:28, 18.46s/it]

At least one correct :: num_correct =  91541
Epoch: 50/500... Step: 2298... Train Loss: 2.650597... Train Acc: 0.398... Val Loss: 4.001685 Val Acc: 0.263...


 12%|████▊                                   | 60/500 [17:58<2:15:58, 18.54s/it]

At least one correct :: num_correct =  97573
Epoch: 60/500... Step: 2298... Train Loss: 2.487294... Train Acc: 0.424... Val Loss: 3.983474 Val Acc: 0.274...


 14%|█████▌                                  | 70/500 [20:59<2:12:18, 18.46s/it]

At least one correct :: num_correct =  102959
Epoch: 70/500... Step: 2298... Train Loss: 2.365567... Train Acc: 0.448... Val Loss: 3.979365 Val Acc: 0.282...


 16%|██████▍                                 | 80/500 [24:00<2:08:42, 18.39s/it]

At least one correct :: num_correct =  106976
Epoch: 80/500... Step: 2298... Train Loss: 2.269219... Train Acc: 0.465... Val Loss: 3.975545 Val Acc: 0.290...


 18%|███████▏                                | 90/500 [27:00<2:05:14, 18.33s/it]

At least one correct :: num_correct =  110627
Epoch: 90/500... Step: 2298... Train Loss: 2.189664... Train Acc: 0.481... Val Loss: 3.995829 Val Acc: 0.295...


 20%|███████▊                               | 100/500 [30:00<2:02:15, 18.34s/it]

At least one correct :: num_correct =  113262
Epoch: 100/500... Step: 2298... Train Loss: 2.127627... Train Acc: 0.493... Val Loss: 4.004541 Val Acc: 0.299...


 22%|████████▌                              | 110/500 [32:59<1:58:38, 18.25s/it]

At least one correct :: num_correct =  115413
Epoch: 110/500... Step: 2298... Train Loss: 2.074093... Train Acc: 0.502... Val Loss: 4.037704 Val Acc: 0.301...


 24%|█████████▎                             | 120/500 [35:58<1:55:48, 18.28s/it]

At least one correct :: num_correct =  117846
Epoch: 120/500... Step: 2298... Train Loss: 2.027914... Train Acc: 0.513... Val Loss: 4.062194 Val Acc: 0.302...


 26%|██████████▏                            | 130/500 [38:57<1:52:19, 18.22s/it]

At least one correct :: num_correct =  119491
Epoch: 130/500... Step: 2298... Train Loss: 1.988959... Train Acc: 0.520... Val Loss: 4.072759 Val Acc: 0.308...


 28%|██████████▉                            | 140/500 [41:56<1:49:33, 18.26s/it]

At least one correct :: num_correct =  120589
Epoch: 140/500... Step: 2298... Train Loss: 1.957577... Train Acc: 0.525... Val Loss: 4.092851 Val Acc: 0.307...


 30%|███████████▋                           | 150/500 [44:55<1:46:18, 18.23s/it]

At least one correct :: num_correct =  121945
Epoch: 150/500... Step: 2298... Train Loss: 1.925914... Train Acc: 0.530... Val Loss: 4.107124 Val Acc: 0.312...


 32%|████████████▍                          | 160/500 [47:54<1:43:13, 18.22s/it]

At least one correct :: num_correct =  123763
Epoch: 160/500... Step: 2298... Train Loss: 1.892598... Train Acc: 0.538... Val Loss: 4.141226 Val Acc: 0.314...


 34%|█████████████▎                         | 170/500 [50:53<1:40:41, 18.31s/it]

At least one correct :: num_correct =  124753
Epoch: 170/500... Step: 2298... Train Loss: 1.869160... Train Acc: 0.543... Val Loss: 4.173375 Val Acc: 0.316...


 36%|██████████████                         | 180/500 [53:53<1:37:43, 18.32s/it]

At least one correct :: num_correct =  126158
Epoch: 180/500... Step: 2298... Train Loss: 1.846999... Train Acc: 0.549... Val Loss: 4.184699 Val Acc: 0.315...


 38%|██████████████▊                        | 190/500 [56:52<1:34:45, 18.34s/it]

At least one correct :: num_correct =  126725
Epoch: 190/500... Step: 2298... Train Loss: 1.829384... Train Acc: 0.551... Val Loss: 4.207950 Val Acc: 0.316...


 40%|███████████████▌                       | 200/500 [59:53<1:32:05, 18.42s/it]

At least one correct :: num_correct =  127534
Epoch: 200/500... Step: 2298... Train Loss: 1.812188... Train Acc: 0.555... Val Loss: 4.219138 Val Acc: 0.317...


 42%|███████████████▌                     | 210/500 [1:02:54<1:29:06, 18.43s/it]

At least one correct :: num_correct =  128431
Epoch: 210/500... Step: 2298... Train Loss: 1.794707... Train Acc: 0.559... Val Loss: 4.232137 Val Acc: 0.318...


 44%|████████████████▎                    | 220/500 [1:05:55<1:26:20, 18.50s/it]

At least one correct :: num_correct =  129913
Epoch: 220/500... Step: 2298... Train Loss: 1.770338... Train Acc: 0.565... Val Loss: 4.266524 Val Acc: 0.318...


 46%|█████████████████                    | 230/500 [1:08:57<1:23:11, 18.49s/it]

At least one correct :: num_correct =  130267
Epoch: 230/500... Step: 2298... Train Loss: 1.762234... Train Acc: 0.567... Val Loss: 4.266339 Val Acc: 0.321...


 48%|█████████████████▊                   | 240/500 [1:11:58<1:20:11, 18.50s/it]

At least one correct :: num_correct =  131088
Epoch: 240/500... Step: 2298... Train Loss: 1.741728... Train Acc: 0.570... Val Loss: 4.285909 Val Acc: 0.324...


 50%|██████████████████▌                  | 250/500 [1:14:59<1:16:59, 18.48s/it]

At least one correct :: num_correct =  131390
Epoch: 250/500... Step: 2298... Train Loss: 1.731206... Train Acc: 0.572... Val Loss: 4.325960 Val Acc: 0.322...


 52%|███████████████████▏                 | 260/500 [1:18:01<1:14:08, 18.53s/it]

At least one correct :: num_correct =  132195
Epoch: 260/500... Step: 2298... Train Loss: 1.718839... Train Acc: 0.575... Val Loss: 4.320296 Val Acc: 0.323...


 54%|███████████████████▉                 | 270/500 [1:21:02<1:10:44, 18.45s/it]

At least one correct :: num_correct =  132988
Epoch: 270/500... Step: 2298... Train Loss: 1.703533... Train Acc: 0.578... Val Loss: 4.326797 Val Acc: 0.320...


 56%|████████████████████▋                | 280/500 [1:24:04<1:08:02, 18.56s/it]

At least one correct :: num_correct =  132860
Epoch: 280/500... Step: 2298... Train Loss: 1.702639... Train Acc: 0.578... Val Loss: 4.345622 Val Acc: 0.326...


 58%|█████████████████████▍               | 290/500 [1:27:06<1:04:49, 18.52s/it]

At least one correct :: num_correct =  133500
Epoch: 290/500... Step: 2298... Train Loss: 1.689760... Train Acc: 0.581... Val Loss: 4.378185 Val Acc: 0.325...


 60%|██████████████████████▏              | 300/500 [1:30:08<1:01:52, 18.56s/it]

At least one correct :: num_correct =  134334
Epoch: 300/500... Step: 2298... Train Loss: 1.676381... Train Acc: 0.584... Val Loss: 4.391687 Val Acc: 0.323...


 62%|████████████████████████▏              | 310/500 [1:33:09<58:27, 18.46s/it]

At least one correct :: num_correct =  134763
Epoch: 310/500... Step: 2298... Train Loss: 1.668475... Train Acc: 0.586... Val Loss: 4.403092 Val Acc: 0.325...


 64%|████████████████████████▉              | 320/500 [1:36:11<55:32, 18.51s/it]

At least one correct :: num_correct =  135248
Epoch: 320/500... Step: 2298... Train Loss: 1.656698... Train Acc: 0.588... Val Loss: 4.411882 Val Acc: 0.324...


 66%|█████████████████████████▋             | 330/500 [1:39:12<52:24, 18.50s/it]

At least one correct :: num_correct =  135917
Epoch: 330/500... Step: 2298... Train Loss: 1.643023... Train Acc: 0.591... Val Loss: 4.444364 Val Acc: 0.328...


 68%|██████████████████████████▌            | 340/500 [1:42:14<49:27, 18.55s/it]

At least one correct :: num_correct =  136140
Epoch: 340/500... Step: 2298... Train Loss: 1.643336... Train Acc: 0.592... Val Loss: 4.441659 Val Acc: 0.327...


 70%|███████████████████████████▎           | 350/500 [1:45:16<46:20, 18.53s/it]

At least one correct :: num_correct =  136744
Epoch: 350/500... Step: 2298... Train Loss: 1.631197... Train Acc: 0.595... Val Loss: 4.445152 Val Acc: 0.326...


 72%|████████████████████████████           | 360/500 [1:48:18<43:09, 18.49s/it]

At least one correct :: num_correct =  136841
Epoch: 360/500... Step: 2298... Train Loss: 1.626862... Train Acc: 0.595... Val Loss: 4.465076 Val Acc: 0.328...


 74%|████████████████████████████▊          | 370/500 [1:51:18<39:49, 18.38s/it]

At least one correct :: num_correct =  137032
Epoch: 370/500... Step: 2298... Train Loss: 1.618307... Train Acc: 0.596... Val Loss: 4.470961 Val Acc: 0.330...


 76%|█████████████████████████████▋         | 380/500 [1:54:18<36:42, 18.35s/it]

At least one correct :: num_correct =  137806
Epoch: 380/500... Step: 2298... Train Loss: 1.606089... Train Acc: 0.599... Val Loss: 4.501866 Val Acc: 0.329...


 78%|██████████████████████████████▍        | 390/500 [1:57:18<33:27, 18.25s/it]

At least one correct :: num_correct =  138113
Epoch: 390/500... Step: 2298... Train Loss: 1.602644... Train Acc: 0.601... Val Loss: 4.489742 Val Acc: 0.330...


 80%|███████████████████████████████▏       | 400/500 [2:00:18<30:33, 18.33s/it]

At least one correct :: num_correct =  138297
Epoch: 400/500... Step: 2298... Train Loss: 1.596902... Train Acc: 0.602... Val Loss: 4.518561 Val Acc: 0.331...


 82%|███████████████████████████████▉       | 410/500 [2:03:17<27:21, 18.24s/it]

At least one correct :: num_correct =  138321
Epoch: 410/500... Step: 2298... Train Loss: 1.591284... Train Acc: 0.602... Val Loss: 4.522983 Val Acc: 0.329...


 84%|████████████████████████████████▊      | 420/500 [2:06:16<24:22, 18.28s/it]

At least one correct :: num_correct =  138938
Epoch: 420/500... Step: 2298... Train Loss: 1.584429... Train Acc: 0.604... Val Loss: 4.539976 Val Acc: 0.331...


 86%|█████████████████████████████████▌     | 430/500 [2:09:15<21:14, 18.21s/it]

At least one correct :: num_correct =  139355
Epoch: 430/500... Step: 2298... Train Loss: 1.577017... Train Acc: 0.606... Val Loss: 4.535849 Val Acc: 0.332...


 88%|██████████████████████████████████▎    | 440/500 [2:12:14<18:16, 18.27s/it]

At least one correct :: num_correct =  139397
Epoch: 440/500... Step: 2298... Train Loss: 1.574910... Train Acc: 0.606... Val Loss: 4.564084 Val Acc: 0.330...


 90%|███████████████████████████████████    | 450/500 [2:15:13<15:13, 18.27s/it]

At least one correct :: num_correct =  139844
Epoch: 450/500... Step: 2298... Train Loss: 1.567097... Train Acc: 0.608... Val Loss: 4.551934 Val Acc: 0.333...


 92%|███████████████████████████████████▉   | 460/500 [2:18:13<12:15, 18.40s/it]

At least one correct :: num_correct =  139903
Epoch: 460/500... Step: 2298... Train Loss: 1.565761... Train Acc: 0.609... Val Loss: 4.593003 Val Acc: 0.332...


 94%|████████████████████████████████████▋  | 470/500 [2:21:14<09:11, 18.39s/it]

At least one correct :: num_correct =  140709
Epoch: 470/500... Step: 2298... Train Loss: 1.552564... Train Acc: 0.612... Val Loss: 4.595441 Val Acc: 0.332...


 96%|█████████████████████████████████████▍ | 480/500 [2:24:15<06:09, 18.48s/it]

At least one correct :: num_correct =  140884
Epoch: 480/500... Step: 2298... Train Loss: 1.550345... Train Acc: 0.613... Val Loss: 4.600771 Val Acc: 0.331...


 98%|██████████████████████████████████████▏| 490/500 [2:27:16<03:04, 18.45s/it]

At least one correct :: num_correct =  140987
Epoch: 490/500... Step: 2298... Train Loss: 1.547027... Train Acc: 0.613... Val Loss: 4.597930 Val Acc: 0.333...


100%|███████████████████████████████████████| 500/500 [2:30:17<00:00, 18.04s/it]

At least one correct :: num_correct =  140968
Epoch: 500/500... Step: 2298... Train Loss: 1.543993... Train Acc: 0.613... Val Loss: 4.619086 Val Acc: 0.332...





In [94]:
# Training Block for Music corpus
model = NextWordPredModel(embedding_dimension=EMBEDDING_DIMENSION, hidden_units=256, num_classes=len(unique_words))
train_model(model.to(device), 500, 0.0005)

  2%|▊                                       | 10/500 [04:19<3:37:34, 26.64s/it]

At least one correct :: num_correct =  40176
Epoch: 10/500... Step: 3062... Train Loss: 5.045858... Train Acc: 0.131... Val Loss: 5.305742 Val Acc: 0.133...


  4%|█▌                                      | 20/500 [08:45<3:37:07, 27.14s/it]

At least one correct :: num_correct =  60930
Epoch: 20/500... Step: 3062... Train Loss: 4.159292... Train Acc: 0.199... Val Loss: 4.822711 Val Acc: 0.168...


  6%|██▍                                     | 30/500 [13:13<3:33:40, 27.28s/it]

At least one correct :: num_correct =  78152
Epoch: 30/500... Step: 3062... Train Loss: 3.677666... Train Acc: 0.255... Val Loss: 4.624567 Val Acc: 0.191...


  8%|███▏                                    | 40/500 [17:41<3:29:32, 27.33s/it]

At least one correct :: num_correct =  90309
Epoch: 40/500... Step: 3062... Train Loss: 3.371451... Train Acc: 0.295... Val Loss: 4.532623 Val Acc: 0.209...


 10%|████                                    | 50/500 [22:09<3:24:35, 27.28s/it]

At least one correct :: num_correct =  99356
Epoch: 50/500... Step: 3062... Train Loss: 3.162378... Train Acc: 0.324... Val Loss: 4.490783 Val Acc: 0.219...


 12%|████▊                                   | 60/500 [26:38<3:20:30, 27.34s/it]

At least one correct :: num_correct =  106594
Epoch: 60/500... Step: 3062... Train Loss: 3.006875... Train Acc: 0.348... Val Loss: 4.451553 Val Acc: 0.228...


 14%|█████▌                                  | 70/500 [31:06<3:15:26, 27.27s/it]

At least one correct :: num_correct =  112228
Epoch: 70/500... Step: 3062... Train Loss: 2.891007... Train Acc: 0.366... Val Loss: 4.436703 Val Acc: 0.235...


 16%|██████▍                                 | 80/500 [35:36<3:12:08, 27.45s/it]

At least one correct :: num_correct =  117363
Epoch: 80/500... Step: 3062... Train Loss: 2.796609... Train Acc: 0.383... Val Loss: 4.441587 Val Acc: 0.240...


 18%|███████▏                                | 90/500 [40:05<3:07:01, 27.37s/it]

At least one correct :: num_correct =  120915
Epoch: 90/500... Step: 3062... Train Loss: 2.717921... Train Acc: 0.395... Val Loss: 4.459000 Val Acc: 0.245...


 20%|███████▊                               | 100/500 [44:34<3:02:57, 27.44s/it]

At least one correct :: num_correct =  124495
Epoch: 100/500... Step: 3062... Train Loss: 2.654053... Train Acc: 0.406... Val Loss: 4.444401 Val Acc: 0.249...


 22%|████████▌                              | 110/500 [49:07<3:00:27, 27.76s/it]

At least one correct :: num_correct =  127043
Epoch: 110/500... Step: 3062... Train Loss: 2.603478... Train Acc: 0.415... Val Loss: 4.462943 Val Acc: 0.255...


 24%|█████████▎                             | 120/500 [53:41<2:57:09, 27.97s/it]

At least one correct :: num_correct =  129825
Epoch: 120/500... Step: 3062... Train Loss: 2.552301... Train Acc: 0.424... Val Loss: 4.476026 Val Acc: 0.259...


 26%|██████████▏                            | 130/500 [58:17<2:53:33, 28.14s/it]

At least one correct :: num_correct =  132019
Epoch: 130/500... Step: 3062... Train Loss: 2.512420... Train Acc: 0.431... Val Loss: 4.479403 Val Acc: 0.263...


 28%|██████████▎                          | 140/500 [1:02:54<2:49:53, 28.32s/it]

At least one correct :: num_correct =  133914
Epoch: 140/500... Step: 3062... Train Loss: 2.475772... Train Acc: 0.437... Val Loss: 4.499234 Val Acc: 0.264...


 30%|███████████                          | 150/500 [1:07:45<2:54:12, 29.86s/it]

At least one correct :: num_correct =  135577
Epoch: 150/500... Step: 3062... Train Loss: 2.449784... Train Acc: 0.443... Val Loss: 4.521499 Val Acc: 0.266...


 32%|███████████▊                         | 160/500 [1:12:37<2:48:47, 29.79s/it]

At least one correct :: num_correct =  137908
Epoch: 160/500... Step: 3062... Train Loss: 2.411895... Train Acc: 0.450... Val Loss: 4.531948 Val Acc: 0.270...


 34%|████████████▌                        | 170/500 [1:17:37<2:48:08, 30.57s/it]

At least one correct :: num_correct =  139494
Epoch: 170/500... Step: 3062... Train Loss: 2.386275... Train Acc: 0.455... Val Loss: 4.552912 Val Acc: 0.270...


 36%|█████████████▎                       | 180/500 [1:22:34<2:45:17, 30.99s/it]

At least one correct :: num_correct =  140733
Epoch: 180/500... Step: 3062... Train Loss: 2.361122... Train Acc: 0.459... Val Loss: 4.549980 Val Acc: 0.273...


 38%|██████████████                       | 190/500 [1:27:39<2:39:02, 30.78s/it]

At least one correct :: num_correct =  141919
Epoch: 190/500... Step: 3062... Train Loss: 2.340881... Train Acc: 0.463... Val Loss: 4.566618 Val Acc: 0.275...


 40%|██████████████▊                      | 200/500 [1:32:34<2:29:54, 29.98s/it]

At least one correct :: num_correct =  143358
Epoch: 200/500... Step: 3062... Train Loss: 2.315273... Train Acc: 0.468... Val Loss: 4.571307 Val Acc: 0.279...


 42%|███████████████▌                     | 210/500 [1:37:18<2:21:13, 29.22s/it]

At least one correct :: num_correct =  144712
Epoch: 210/500... Step: 3062... Train Loss: 2.295955... Train Acc: 0.472... Val Loss: 4.587423 Val Acc: 0.279...


 44%|████████████████▎                    | 220/500 [1:41:57<2:12:31, 28.40s/it]

At least one correct :: num_correct =  145334
Epoch: 220/500... Step: 3062... Train Loss: 2.278777... Train Acc: 0.474... Val Loss: 4.586944 Val Acc: 0.283...


 46%|█████████████████                    | 230/500 [1:46:35<2:07:25, 28.32s/it]

At least one correct :: num_correct =  146695
Epoch: 230/500... Step: 3062... Train Loss: 2.260685... Train Acc: 0.479... Val Loss: 4.607208 Val Acc: 0.283...


 48%|█████████████████▊                   | 240/500 [1:51:13<2:02:26, 28.25s/it]

At least one correct :: num_correct =  147606
Epoch: 240/500... Step: 3062... Train Loss: 2.245784... Train Acc: 0.482... Val Loss: 4.608990 Val Acc: 0.284...


 50%|██████████████████▌                  | 250/500 [1:55:50<1:57:27, 28.19s/it]

At least one correct :: num_correct =  148404
Epoch: 250/500... Step: 3062... Train Loss: 2.232591... Train Acc: 0.485... Val Loss: 4.633955 Val Acc: 0.288...


 52%|███████████████████▏                 | 260/500 [2:00:26<1:52:29, 28.12s/it]

At least one correct :: num_correct =  149367
Epoch: 260/500... Step: 3062... Train Loss: 2.214894... Train Acc: 0.488... Val Loss: 4.646677 Val Acc: 0.288...


 54%|███████████████████▉                 | 270/500 [2:05:04<1:48:14, 28.24s/it]

At least one correct :: num_correct =  149690
Epoch: 270/500... Step: 3062... Train Loss: 2.209621... Train Acc: 0.489... Val Loss: 4.665132 Val Acc: 0.289...


 56%|████████████████████▋                | 280/500 [2:09:40<1:43:00, 28.09s/it]

At least one correct :: num_correct =  150801
Epoch: 280/500... Step: 3062... Train Loss: 2.191867... Train Acc: 0.492... Val Loss: 4.679539 Val Acc: 0.291...


 58%|█████████████████████▍               | 290/500 [2:14:17<1:38:42, 28.20s/it]

At least one correct :: num_correct =  151762
Epoch: 290/500... Step: 3062... Train Loss: 2.177273... Train Acc: 0.495... Val Loss: 4.696355 Val Acc: 0.293...


 60%|██████████████████████▏              | 300/500 [2:18:53<1:33:51, 28.16s/it]

At least one correct :: num_correct =  152714
Epoch: 300/500... Step: 3062... Train Loss: 2.166719... Train Acc: 0.499... Val Loss: 4.704591 Val Acc: 0.291...


 62%|██████████████████████▉              | 310/500 [2:23:33<1:31:05, 28.77s/it]

At least one correct :: num_correct =  153230
Epoch: 310/500... Step: 3062... Train Loss: 2.155856... Train Acc: 0.500... Val Loss: 4.701663 Val Acc: 0.294...


 64%|███████████████████████▋             | 320/500 [2:28:17<1:26:27, 28.82s/it]

At least one correct :: num_correct =  153482
Epoch: 320/500... Step: 3062... Train Loss: 2.144790... Train Acc: 0.501... Val Loss: 4.714164 Val Acc: 0.296...


 66%|████████████████████████▍            | 330/500 [2:33:13<1:24:46, 29.92s/it]

At least one correct :: num_correct =  154307
Epoch: 330/500... Step: 3062... Train Loss: 2.135937... Train Acc: 0.504... Val Loss: 4.722053 Val Acc: 0.297...


 68%|█████████████████████████▏           | 340/500 [2:38:18<1:23:28, 31.31s/it]

At least one correct :: num_correct =  154532
Epoch: 340/500... Step: 3062... Train Loss: 2.127518... Train Acc: 0.505... Val Loss: 4.749752 Val Acc: 0.296...


 70%|█████████████████████████▉           | 350/500 [2:43:27<1:18:25, 31.37s/it]

At least one correct :: num_correct =  155484
Epoch: 350/500... Step: 3062... Train Loss: 2.117427... Train Acc: 0.508... Val Loss: 4.747963 Val Acc: 0.299...


 72%|██████████████████████████▋          | 360/500 [2:48:39<1:14:12, 31.80s/it]

At least one correct :: num_correct =  156005
Epoch: 360/500... Step: 3062... Train Loss: 2.108857... Train Acc: 0.509... Val Loss: 4.756731 Val Acc: 0.301...


 74%|███████████████████████████▍         | 370/500 [2:53:48<1:07:19, 31.07s/it]

At least one correct :: num_correct =  156679
Epoch: 370/500... Step: 3062... Train Loss: 2.097504... Train Acc: 0.512... Val Loss: 4.749269 Val Acc: 0.301...


 76%|█████████████████████████████▋         | 380/500 [2:58:39<59:18, 29.66s/it]

At least one correct :: num_correct =  157298
Epoch: 380/500... Step: 3062... Train Loss: 2.091279... Train Acc: 0.514... Val Loss: 4.788580 Val Acc: 0.297...


 78%|██████████████████████████████▍        | 390/500 [3:03:33<54:21, 29.65s/it]

At least one correct :: num_correct =  157594
Epoch: 390/500... Step: 3062... Train Loss: 2.084248... Train Acc: 0.515... Val Loss: 4.789634 Val Acc: 0.302...


 80%|███████████████████████████████▏       | 400/500 [3:08:22<49:20, 29.60s/it]

At least one correct :: num_correct =  157921
Epoch: 400/500... Step: 3062... Train Loss: 2.073718... Train Acc: 0.516... Val Loss: 4.800413 Val Acc: 0.304...


 82%|███████████████████████████████▉       | 410/500 [3:13:12<44:07, 29.41s/it]

At least one correct :: num_correct =  158377
Epoch: 410/500... Step: 3062... Train Loss: 2.066575... Train Acc: 0.517... Val Loss: 4.800718 Val Acc: 0.305...


 84%|████████████████████████████████▊      | 420/500 [3:17:53<39:08, 29.36s/it]

At least one correct :: num_correct =  158809
Epoch: 420/500... Step: 3062... Train Loss: 2.063128... Train Acc: 0.518... Val Loss: 4.825942 Val Acc: 0.305...


 86%|█████████████████████████████████▌     | 430/500 [3:22:37<33:18, 28.54s/it]

At least one correct :: num_correct =  159381
Epoch: 430/500... Step: 3062... Train Loss: 2.057430... Train Acc: 0.520... Val Loss: 4.833952 Val Acc: 0.305...


 88%|██████████████████████████████████▎    | 440/500 [3:27:23<29:42, 29.70s/it]

At least one correct :: num_correct =  158354
Epoch: 440/500... Step: 3062... Train Loss: 2.069957... Train Acc: 0.517... Val Loss: 4.834384 Val Acc: 0.305...


 90%|███████████████████████████████████    | 450/500 [3:32:06<23:50, 28.61s/it]

At least one correct :: num_correct =  159733
Epoch: 450/500... Step: 3062... Train Loss: 2.048972... Train Acc: 0.521... Val Loss: 4.839302 Val Acc: 0.308...


 92%|███████████████████████████████████▉   | 460/500 [3:36:50<19:30, 29.27s/it]

At least one correct :: num_correct =  160405
Epoch: 460/500... Step: 3062... Train Loss: 2.038199... Train Acc: 0.524... Val Loss: 4.856903 Val Acc: 0.306...


 94%|████████████████████████████████████▋  | 470/500 [3:41:33<14:46, 29.54s/it]

At least one correct :: num_correct =  160913
Epoch: 470/500... Step: 3062... Train Loss: 2.032770... Train Acc: 0.525... Val Loss: 4.872825 Val Acc: 0.305...


 96%|█████████████████████████████████████▍ | 480/500 [3:46:23<10:04, 30.23s/it]

At least one correct :: num_correct =  161122
Epoch: 480/500... Step: 3062... Train Loss: 2.026696... Train Acc: 0.526... Val Loss: 4.873804 Val Acc: 0.307...


 98%|██████████████████████████████████████▏| 490/500 [3:51:25<05:09, 30.92s/it]

At least one correct :: num_correct =  161804
Epoch: 490/500... Step: 3062... Train Loss: 2.015396... Train Acc: 0.528... Val Loss: 4.895746 Val Acc: 0.309...


100%|███████████████████████████████████████| 500/500 [3:56:31<00:00, 28.38s/it]

At least one correct :: num_correct =  161938
Epoch: 500/500... Step: 3062... Train Loss: 2.010200... Train Acc: 0.529... Val Loss: 4.889572 Val Acc: 0.308...





# Compute Test Accuracy

In [95]:
import math

test_losses = []
num_correct = 0
total = 0
h = model.init_hidden(batch_size)

model.eval()
h = tuple([each.data for each in h])
for inputs, labels in test_loader:
    
    inputs, labels = inputs.to(device), labels.to(device)
    output, (test_h_state, test_c_state) = model(inputs.float(), h)
    h = (test_h_state.detach(), test_c_state.detach())

    pred = output.argmax(dim=1)
    correct = pred.eq(labels.argmax(dim=1)).sum().item()

    num_correct += correct
    total += inputs.shape[0]
test_acc = round(num_correct/total, 3)
print("Test Accuracy = ", test_acc)

Test Accuracy =  0.309


# Save checkpoints for the trained models

In [96]:
from datetime import datetime
now = datetime.now().strftime("%Y-%m-%d::%H:%M:%S")

PATH = 'model-' + keyword + "-" + str(len(X)) + "-" + now + "-" + str(test_acc) + ".pt"

torch.save({
            'model_state_dict': model.state_dict()
            }, PATH)
pickle.dump(word_embeddings, open("word_embeddings_" + keyword + "_" + now + ".pk", "wb"))

In [None]:
input_sequence = word_seq_to_padded_embeddings(np.array(word_tokens), word_index, 5).unsqueeze(dim=1)

# LOAD TRAINED MODEL

In [33]:
MODEL_PATH = "model-287430-2022-11-21::10:22:40-0.322"
word_embeddings = pickle.load(open("word_embeddings_2022-11-21::10:22:40.pk", "rb"))

model = NextWordPredModel(embedding_dimension=EMBEDDING_DIMENSION, hidden_units=256, num_classes=len(unique_words))
checkpoint = torch.load(MODEL_PATH)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

NextWordPredModel(
  (rnn): LSTM(64, 256, num_layers=2, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=256, out_features=7901, bias=True)
)

# TEST THE TRAINED MODEL

In [97]:
def predict_next_word(word_model, sentence):
    words = sentence.strip().split()
    word_tokens = []
    for w in words:
        if w in word_index:
            word_tokens.append(word_index[w])
    
    input_sequence = word_embeddings[word_tokens].unsqueeze(dim=0)
    prediction, _ = word_model(input_sequence.to(device))
    predicted_wordindex = prediction.argmax(dim=1).item()
    predicted_word = index_word[predicted_wordindex]
    return predicted_word
    

In [98]:
processed_corpus[500:1000]

' styles of new mexico red dirt tejano and texas country its popularized roots originate in the southern and southwestern united states of the early 1920s country music often consists of ballads and honky tonk dance tunes with generally simple form folk lyrics and harmonies often accompanied by string instruments such as electric and acoustic guitars steel guitars such as pedal steels and dobros banjos and fiddles as well as harmonicas blues modes have been used extensively throughout its recorde'

In [100]:
pred_word = predict_next_word(model.to(device), "the celebrities arrived at the")
print("Next Word prediction = ", pred_word)

Next Word prediction =  ownership


In [109]:
pred_word = predict_next_word(model.to(device), "the guitar player was")
print("Next Word prediction = ", pred_word)

Next Word prediction =  spontaneously


In [111]:
pred_word = predict_next_word(model.to(device), "the song browser")
print("Next Word prediction = ", pred_word)

Next Word prediction =  depends


In [122]:
pred_word = predict_next_word(model.to(device), "rock music is rated very high in")
print("Next Word prediction = ", pred_word)

Next Word prediction =  opeth


In [134]:
sentence = "the main man was"
for i in range(10):
    predct = predict_next_word(model, sentence)
    sentence = sentence + " " + predct
print(sentence)

the main man was lyons deeper lighter frostbite maximum strong and minute hiatus hawaiian
