In [1]:
using Flux
using Flux: onehot, onehotbatch, crossentropy, reset!, throttle

In [20]:
corpora = Dict()

for file in readdir("data/lang")
  lang = Symbol(match(r"(.*)\.txt", file).captures[1])
  corpus = split(String(read("data/lang/$file")), ".")
  corpus = strip.(normalize_string.(corpus, casefold=true, stripmark=true))
  corpus = filter(!isempty, corpus)
  corpora[lang] = corpus
end

corpora

Dict{Any,Any} with 5 entries:
  :en => String["wikipedia (/ˌwɪkɪˈpiːdiə/ ( listen)wik-i-pee-dee-ə or /ˌwɪkiˈp…
  :it => String["wikipedia (pronuncia: vedi sotto) e un'enciclopediaonline a co…
  :fr => String["wikipedia ecouter est un projet d'encyclopedie universelle, mu…
  :es => String["wikipedia es una enciclopedialibre,[nota 2]\u200bpoliglota y e…
  :da => String["wikipedia er en encyklopædi med abent indhold, skrevet i samar…

In [21]:
langs = collect(keys(corpora))

5-element Array{Any,1}:
 :en
 :it
 :fr
 :es
 :da

In [22]:
alphabet = ['a':'z'; '0':'9'; ' '; '\n'; '_'];

In [23]:
# See which chars will be represented as "unknown"
unique(filter(x -> x ∉ alphabet, join(vcat(values(corpora)...))))

148-element Array{Char,1}:
 '('
 '/'
 'ˌ'
 'ɪ'
 'ˈ'
 'ː'
 'ə'
 ' '
 ')'
 '-'
 '['
 ']'
 ','
 ⋮  
 'ব'
 'ল'
 'দ'
 'শ'
 'চ'
 'ট'
 'ম'
 'ঢ'
 'ক'
 'খ'
 'হ'
 'স'

In [24]:
dataset = [(onehotbatch(s, alphabet, '_'), onehot(l, langs))
           for l in langs for s in corpora[l]] |> shuffle
        

8284-element Array{Tuple{Flux.OneHotMatrix{Array{Flux.OneHotVector,1}},Flux.OneHotVector},1}:
 (Bool[false false … true false; false false … false false; … ; false false … false false; false false … false true], Bool[false, false, false, true, false])  
 (Bool[false false … false true; false false … false false; … ; false false … false false; false false … false false], Bool[false, false, false, true, false]) 
 (Bool[false false; false false; … ; false false; true true], Bool[false, false, true, false, false])                                                          
 (Bool[false false … false false; false false … false false; … ; false false … false false; false false … true false], Bool[false, false, false, true, false]) 
 (Bool[true false … false false; false false … false false; … ; false false … false false; false false … false true], Bool[false, false, true, false, false])  
 (Bool[false false … false true; false false … false false; … ; false false … false false; false false … f

In [25]:
train, test = dataset[1:end-100], dataset[end-99:end];

In [26]:
N = 15

scanner = Chain(Dense(length(alphabet), N, σ), LSTM(N, N))
encoder = Dense(N, length(langs))

function model(x)
  state = scanner.(x.data)[end]
  reset!(scanner)
  softmax(encoder(state))
end

loss(x, y) = crossentropy(model(x), y)

loss (generic function with 1 method)

In [27]:
testloss() = mean(loss(t...) for t in test)
opt = ADAM(params(scanner, encoder))
evalcb = () -> @show testloss()

(::#27) (generic function with 1 method)

In [28]:
Flux.train!(loss, train, opt, cb = throttle(evalcb, 10));

testloss() = 1.7428789251582644 (tracked)
testloss() = 1.5618301730245043 (tracked)
testloss() = 1.535118687642855 (tracked)
testloss() = 1.5908612062229408 (tracked)
testloss() = 1.5365635495624956 (tracked)
testloss() = 1.5579613625184499 (tracked)
testloss() = 1.5456924576384798 (tracked)
testloss() = 1.5242060208609067 (tracked)
testloss() = 1.5477429165332373 (tracked)
testloss() = 1.525041853689614 (tracked)
testloss() = 1.6352322728114996 (tracked)
testloss() = 1.5892076730862938 (tracked)
testloss() = 1.5479989052454894 (tracked)
testloss() = 1.528102725465964 (tracked)
testloss() = 1.5420772173602697 (tracked)
testloss() = 1.5142100940304297 (tracked)
testloss() = 1.5148853761441379 (tracked)
testloss() = 1.5073587878221621 (tracked)
testloss() = 1.5024050250698098 (tracked)
testloss() = 1.5263587579538793 (tracked)
testloss() = 1.503265196174578 (tracked)
testloss() = 1.4937646573117385 (tracked)
testloss() = 1.5007020557881545 (tracked)
testloss() = 1.4817318556684045 (track

In [29]:
model(onehotbatch(normalize_string("c'é una bella filosofia", casefold=true, stripmark=true),
        alphabet, '_'))

Tracked 5-element Array{Float64,1}:
 0.0714032
 0.297722 
 0.0470815
 0.536765 
 0.0470282

In [30]:
using Interact, Plots

In [31]:
predict(s) =
    isempty(s) ?
        softmax(ones(length(langs))) :
        model(onehotbatch(normalize_string(s, casefold=true, stripmark=true), alphabet, '_')).data

predict (generic function with 1 method)

In [32]:
@manipulate for s = "c'é una bella filosofia"
    bar(String.(langs), predict(s),
        label=["Probability"], ylims=(0,1))
end