In [1]:
using Pkg; Pkg.activate("/home/dhairyagandhi96/temp/model-zoo/script/.."); Pkg.status();

using Flux
using Flux: onehot, onehotbatch, crossentropy, reset!, throttle
using Statistics: mean
using Random
using Unicode
include("scrape.jl")
# Sentence corpora keyed by language. Every `corpus/<lang>.txt` file contributes
# one entry: the key is `Symbol("<lang>")`, the value is the list of cleaned,
# non-empty sentences found in that file.
corpora = Dict{Symbol,Vector{SubString{String}}}()

cd(@__DIR__)
for file in readdir("corpus")
    # Only `<lang>.txt` files are corpora. `match` returns `nothing` for anything
    # else (hidden files, backups, ...), so skip those instead of crashing on
    # `.captures`. The regex is anchored so that e.g. `en.txt.bak` is not
    # mistaken for the `en` corpus.
    m = match(r"^(.+)\.txt$", file)
    m === nothing && continue
    lang = Symbol(m.captures[1])

    # Split the raw text into sentences on '.', then case-fold, strip diacritics
    # and trim surrounding whitespace from each sentence.
    sentences = split(read(joinpath("corpus", file), String), ".")
    sentences = strip.(Unicode.normalize.(sentences, casefold=true, stripmark=true))

    # Drop sentences that are empty after cleaning.
    corpora[lang] = filter(!isempty, sentences)
end

# Languages present in the corpus directory.
langs = collect(keys(corpora))

# Characters the model distinguishes: lowercase letters, digits, space, newline,
# and '_' (used further down as the stand-in for every other character).
alphabet = ['a':'z'; '0':'9'; ' '; '\n'; '_']

    Status `~/temp/model-zoo/Project.toml`
  [1520ce14]   AbstractTrees v0.2.1
  [fbb218c0] ↑ BSON v0.2.3 ⇒ v0.2.4
  [54eefc05]   Cascadia v0.4.0
  [8f4d0f93]   Conda v1.3.0
  [864edb3b] ↑ DataStructures v0.17.0 ⇒ v0.17.5
  [31c24e10] ↑ Distributions v0.21.3 ⇒ v0.21.5
  [587475ba]   Flux v0.9.0
  [708ec375]   Gumbo v0.5.1
  [b0807396]   Gym v1.1.3
  [cd3eb016] ↑ HTTP v0.8.6 ⇒ v0.8.7
  [6218d12a]   ImageMagick v0.7.5
  [916415d5]   Images v0.18.0
  [e5e0dc1b]   Juno v0.7.2
  [ca7b5df7]   MFCC v0.3.1
  [dbeba491] + Metalhead v0.4.0 #c4d1eba (https://github.com/FluxML/Metalhead.jl.git)
  [91a5bcdd] ↑ Plots v0.26.3 ⇒ v0.27.0
  [2913bbd2]   StatsBase v0.32.0
  [98b73d46]   Trebuchet v0.1.0
  [8149f6b0] ↑ WAV v1.0.2 ⇒ v1.0.3
  [10745b16]   Statistics 
  [4ec0a83e]   Unicode 


39-element Array{Char,1}:
 'a' 
 'b' 
 'c' 
 'd' 
 'e' 
 'f' 
 'g' 
 'h' 
 'i' 
 'j' 
 ⋮   
 '4' 
 '5' 
 '6' 
 '7' 
 '8' 
 '9' 
 ' ' 
 '\n'
 '_' 

Check which characters in the corpora are not part of `alphabet`; these are the ones that will be encoded as the "unknown" symbol `_`.

In [2]:
# Distinct characters that occur somewhere in the corpora but are missing from
# `alphabet`.
let text = join(vcat(values(corpora)...))
    unique(filter(!in(alphabet), text))
end

# One sample per sentence: the sentence one-hot encoded character by character
# over `alphabet` (with '_' passed as the fallback symbol), paired with the
# one-hot label of its language. The samples are shuffled.
dataset = shuffle([(onehotbatch(sentence, alphabet, '_'), onehot(lang, langs))
                   for lang in langs for sentence in corpora[lang]])

# The last 100 shuffled samples are held out for evaluation.
train = dataset[1:end-100]
test = dataset[end-99:end]

# Size of the per-character embedding and of the LSTM state.
N = 15

# `scanner` consumes one one-hot character at a time; `encoder` maps a scanner
# output to one score per language.
scanner = Chain(
    Dense(length(alphabet), N, σ),
    LSTM(N, N),
)
encoder = Dense(N, length(langs))

# Classify one encoded sentence `x`: run `scanner` over every element of
# `x.data` in order, keep the output for the final element, clear the
# recurrent state so the next sentence starts fresh, and return the softmax of
# the language scores.
function model(x)
    hidden = last(scanner.(x.data))
    reset!(scanner)
    return softmax(encoder(hidden))
end

# Cross-entropy between the model's prediction for `x` and the target `y`.
function loss(x, y)
    return crossentropy(model(x), y)
end

# Mean loss over the held-out `test` samples.
function testloss()
    return mean(loss(x, y) for (x, y) in test)
end

# Trainable parameters of both model parts, and the optimiser that updates them.
ps = params(scanner, encoder)
opt = ADAM()

# Progress callback: print the current held-out loss.
evalcb = () -> @show testloss()

# One pass over `train`; the callback is throttled with an argument of 10.
Flux.train!(loss, ps, train, opt; cb = throttle(evalcb, 10))

testloss() = 1.6321858f0 (tracked)
testloss() = 1.5675306f0 (tracked)
testloss() = 1.5592979f0 (tracked)
testloss() = 1.5649234f0 (tracked)
testloss() = 1.5684868f0 (tracked)
testloss() = 1.5389409f0 (tracked)
testloss() = 1.5401595f0 (tracked)
testloss() = 1.5568221f0 (tracked)
testloss() = 1.5433629f0 (tracked)
testloss() = 1.5213215f0 (tracked)
testloss() = 1.5453626f0 (tracked)
testloss() = 1.4517982f0 (tracked)
testloss() = 1.4952666f0 (tracked)
testloss() = 1.4635943f0 (tracked)
testloss() = 1.4445903f0 (tracked)
testloss() = 1.410103f0 (tracked)
testloss() = 1.4697593f0 (tracked)
testloss() = 1.3583783f0 (tracked)
testloss() = 1.2699927f0 (tracked)
testloss() = 1.3471761f0 (tracked)
testloss() = 1.4361632f0 (tracked)
testloss() = 1.3016007f0 (tracked)
testloss() = 1.3721728f0 (tracked)
testloss() = 1.3477708f0 (tracked)
testloss() = 1.4562292f0 (tracked)
testloss() = 1.4022727f0 (tracked)
testloss() = 1.353598f0 (tracked)
testloss() = 1.3427544f0 (tracked)
testloss() = 1.3287576