In [1]:
using Flux

"""
    init_weights(dims...; hidden = 1150)

Return a `Float32` array of size `dims` whose entries are drawn uniformly from
`[-1/sqrt(hidden), 1/sqrt(hidden))`, i.e. the `[-1/sqrt(H), 1/sqrt(H)]` weight range
for an LSTM with hidden size `H = hidden`. Intended as a weight initializer
(e.g. the `init` keyword of `LSTM`).

# Throws
- `ArgumentError` if `hidden` is not positive.
"""
function init_weights(dims...; hidden::Integer = 1150)
    hidden > 0 || throw(ArgumentError("hidden must be positive, got $hidden"))
    # Keep the scale Float32: multiplying a Float32 array by a Float64 scalar
    # (such as `sqrt(1/1150)`) would promote the whole result to Float64.
    scale = inv(sqrt(Float32(hidden)))
    # `rand` yields [0, 1); map it to [-1, 1) so the weights are centred on zero.
    return (rand(Float32, dims...) .* 2f0 .- 1f0) .* scale
end

# Standalone 400 → 1150 LSTM whose weight initializer is `init_weights`, passed as
# LSTM's `init` keyword. Intended weight range: [-1/sqrt(H), 1/sqrt(H)], H = hidden size.
a = let ninputs = 400, nhidden = 1150
    LSTM(ninputs, nhidden, init = init_weights)
end

LanguageModel

In [2]:
"""
    LanguageModel(; embedDropProb = 0.05, wordDropProb = 0.4, hidDropProb = 0.5,
                    LayerDropProb = 0.3, FinalDropProb = 0.4)

Stacked three-layer LSTM (400 → 1150 → 1150 → 400) together with its dropout settings.

The input (`Wi`) and recurrent (`Wh`) weight matrices of all three LSTM cells are
re-drawn uniformly from `[-1/sqrt(H), 1/sqrt(H))` with `H = 1150`, as `Float32`.
Other cell parameters keep whatever `LSTM` initialised them with.

# Keywords
- `embedDropProb`: stored in the `embedDropProb` field; not used by this constructor.
- `wordDropProb`: probability of the `wordDropout` layer.
- `hidDropProb`: probability of the `hiddenDropout` layer.
- `LayerDropProb`: dropout after `lstmLayer1` and `lstmLayer2` in `RecurrentLayers`.
- `FinalDropProb`: dropout after `lstmLayer3` in `RecurrentLayers`.

Each probability may be any `Real` in `[0, 1]`; it is stored and used as `Float64`.

# Throws
- `ArgumentError` if a probability lies outside `[0, 1]` (or is `NaN`).
"""
mutable struct LanguageModel
    embedMat :: TrackedArray       # NOTE(review): 10×5 placeholder, entries in [0, 0.1)
    lstmLayer1 :: Flux.Recur       # LSTM 400 → 1150
    lstmLayer2 :: Flux.Recur       # LSTM 1150 → 1150
    lstmLayer3 :: Flux.Recur       # LSTM 1150 → 400
    embedDropProb :: Float64       # stored only; no layer is built from it here
    wordDropout :: Dropout
    hiddenDropout :: Dropout
    LayerDropProb :: Float64       # dropout between stacked LSTM layers
    FinalDropProb :: Float64       # dropout after the last LSTM layer
    # lstmLayer1..3 interleaved with dropout; holds the very same Recur instances
    # stored in the fields above (not copies).
    RecurrentLayers :: Chain

    function LanguageModel(; embedDropProb::Real = 0.05, wordDropProb::Real = 0.4,
                           hidDropProb::Real = 0.5, LayerDropProb::Real = 0.3,
                           FinalDropProb::Real = 0.4)
        probs = (embedDropProb = Float64(embedDropProb),
                 wordDropProb = Float64(wordDropProb),
                 hidDropProb = Float64(hidDropProb),
                 LayerDropProb = Float64(LayerDropProb),
                 FinalDropProb = Float64(FinalDropProb))
        for (name, p) in pairs(probs)
            0 <= p <= 1 || throw(ArgumentError("$name must lie in [0, 1], got $p"))
        end

        # NOTE(review): placeholder embedding kept as-is. If it feeds lstmLayer1, its
        # embedding dimension presumably needs to be 400 — TODO confirm.
        embedMat = param(rand(10, 5) .* 0.1)

        lstm1 = LSTM(400, 1150)
        lstm2 = LSTM(1150, 1150)
        lstm3 = LSTM(1150, 400)

        # Re-draw Wi/Wh uniformly in [-1/sqrt(H), 1/sqrt(H)), H = 1150. The scale is
        # Float32 so the Float32 draws are not promoted to Float64 by the multiplication.
        scale = inv(sqrt(1150f0))
        uniform_weights(dims) = (rand(Float32, dims) .* 2f0 .- 1f0) .* scale
        for cell in (lstm1.cell, lstm2.cell, lstm3.cell)
            cell.Wh = param(uniform_weights(size(cell.Wh)))
            cell.Wi = param(uniform_weights(size(cell.Wi)))
        end

        recurrent = Chain(
            lstm1,
            Dropout(probs.LayerDropProb),
            lstm2,
            Dropout(probs.LayerDropProb),
            lstm3,
            Dropout(probs.FinalDropProb)
        )

        # Every field is supplied here, so no partially-initialised object escapes.
        return new(
            embedMat,
            lstm1,
            lstm2,
            lstm3,
            probs.embedDropProb,
            Dropout(probs.wordDropProb),
            Dropout(probs.hidDropProb),
            probs.LayerDropProb,
            probs.FinalDropProb,
            recurrent
        )
    end
end

LanguageModel([0.0941891 0.0551141 … 0.0658747 0.0672926; 0.0956207 0.0253594 … 0.0320328 0.0894783; … ; 0.0235151 0.000167877 … 0.059368 0.0668135; 0.0271961 0.0374485 … 0.0983765 0.0844762] (tracked), Recur(LSTMCell(400, 1150)), Recur(LSTMCell(1150, 1150)), Recur(LSTMCell(1150, 400)))