In [1]:
# Author: Arman Kabiri
# Date: Feb. 27, 2020
# Email: Arman.Kabiri94@fmail.com

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import gensim
import logging
from tqdm import tqdm
import typing

%matplotlib inline

In [5]:
class LanguageModel(nn.Module):
    """LSTM language model layers: embedding, dropout, LSTM and a linear decoder over the vocabulary.

    Only construction and weight initialization are defined in this block.
    """

    def __init__(self, seq_length:int, n_layers:int, hidden_size:int, n_vocab:int, input_size:int, dropout:float,
                 bidirectional:bool, tie_weights:bool=False, pret_emb_matrix:np.ndarray=None, trainable_emb:bool=True):
        """Build the layers and initialize their weights.

        Args:
            seq_length: Sequence length; stored on the instance but not used by the constructor.
            n_layers: Number of stacked LSTM layers.
            hidden_size: LSTM hidden-state size per direction.
            n_vocab: Vocabulary size (embedding rows and decoder outputs).
            input_size: Embedding dimension, used as the LSTM input size.
            dropout: Dropout probability for ``self.dropout`` and between LSTM layers.
            bidirectional: Whether the LSTM runs in both directions.
            tie_weights: Share the embedding matrix with the decoder weight.
            pret_emb_matrix: Optional pre-trained embedding matrix (numpy array or tensor)
                of shape (n_vocab, input_size).
            trainable_emb: Whether the embedding weight receives gradients.

        Raises:
            AssertionError: If ``pret_emb_matrix`` does not have shape (n_vocab, input_size).
            ValueError: For incompatible flag combinations (raised by ``init_weights``).
        """
        super().__init__()

        # Initializing embedding layer
        if pret_emb_matrix is not None:
            assert n_vocab == pret_emb_matrix.shape[0] and input_size == pret_emb_matrix.shape[1], \
                'pret_emb_matrix must have shape (n_vocab, input_size)'

        self.embedding = nn.Embedding(n_vocab, input_size)

        self.seq_length = seq_length
        self.input_size = input_size
        self.hidden_size = hidden_size
        # nn.LSTM concatenates both directions when bidirectional, so its
        # output feature size is num_directions * hidden_size.
        self.num_directions = 2 if bidirectional else 1

        # Initializing training layers
        self.dropout = nn.Dropout(p=dropout)

        self.rnn = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=n_layers,
                           dropout=dropout, bidirectional=bidirectional)

        self.decoder = nn.Linear(self.num_directions * hidden_size, n_vocab)

        # Network initialization
        self.init_weights(pret_emb_matrix, trainable_emb, tie_weights)

    def init_weights(self, pret_emb_matrix, trainable_emb, tie_weights):
        """Initialize the embedding and decoder weights.

        Args:
            pret_emb_matrix: Optional pre-trained embedding matrix (numpy array or tensor)
                of shape (n_vocab, input_size). When None, embeddings are drawn from U(-0.1, 0.1).
            trainable_emb: Whether the embedding weight requires gradients.
            tie_weights: Reuse the embedding weight Parameter as the decoder weight.

        Raises:
            ValueError: If ``tie_weights`` is combined with a frozen embedding, if no
                pre-trained embedding is given while ``trainable_emb`` is False, or if
                weights are tied while the decoder input size differs from ``input_size``.
        """
        # Tying shares one Parameter between embedding and decoder, so a frozen
        # embedding would silently freeze the decoder too.
        if tie_weights and not trainable_emb:
            raise ValueError('tie_weights and trainable_emb flags should be used in a compatible way.')

        if pret_emb_matrix is None and not trainable_emb:
            raise ValueError('When pre-trained embeddings are not given, weights should be trainable.')

        initrange = 0.1

        if pret_emb_matrix is not None:
            # load_state_dict only accepts tensors; convert numpy input and match the layer dtype.
            weight = torch.as_tensor(pret_emb_matrix, dtype=self.embedding.weight.dtype)
            self.embedding.load_state_dict({'weight': weight})
            self.embedding.weight.requires_grad = trainable_emb
        else:
            self.embedding.weight.data.uniform_(-initrange, initrange)

        self.decoder.bias.data.zero_()
        if tie_weights:
            # The decoder weight is (n_vocab, decoder.in_features) and must match the embedding's (n_vocab, input_size).
            if self.decoder.in_features != self.input_size:
                raise ValueError('When using the tied flag, hidden_size (times 2 if bidirectional) '
                                 'must be equal to input_size')
            self.decoder.weight = self.embedding.weight
        else:
            self.decoder.weight.data.uniform_(-initrange, initrange)