In [None]:
using Flux
using Flux: onehot, onehotbatch, crossentropy, reset!, throttle

In [None]:
# Directory holding one plain-text corpus per language, named `<lang>.txt`.
corpusdir = "data/corpus"

# Map each language (the file name without `.txt`, as a Symbol) to its list of
# sentences. Each file is split on '.', every fragment is case-folded, stripped of
# diacritics and surrounding whitespace, and empty fragments are dropped.
corpora = Dict{Symbol,Vector}()

for file in readdir(corpusdir)
  # Anchored so that e.g. `x.txt.bak` is not mistaken for corpus `x`.
  m = match(r"^(.*)\.txt$", file)
  # Skip anything that is not a `.txt` file (e.g. `.DS_Store`); otherwise `m` is
  # `nothing` and `m.captures` would throw.
  m === nothing && continue
  lang = Symbol(m.captures[1])
  corpus = split(String(read(joinpath(corpusdir, file))), ".")
  corpus = strip.(normalize_string.(corpus, casefold=true, stripmark=true))
  corpus = filter(!isempty, corpus)
  corpora[lang] = corpus
end

corpora

In [None]:
# Language labels, one per corpus. The position of each label in this vector fixes
# which model output (and which one-hot target index) belongs to that language.
langs = keys(corpora) |> collect
# Characters the model sees directly: lowercase letters, digits, space, newline and
# '_', which later serves as the "unknown character" token for onehotbatch.
alphabet = vcat(collect('a':'z'), collect('0':'9'), [' ', '\n', '_']);

In [None]:
# List, in order of first appearance, the characters in the corpora that are not in
# `alphabet`; when the dataset is built, each of them is encoded as the unknown
# token '_'.
let text = join(vcat(values(corpora)...))
  unique(filter(c -> !(c in alphabet), text))
end

In [None]:
# Number of (shuffled) examples held out for evaluation.
ntest = 100

# One (input, target) pair per sentence: the input is a one-hot matrix over
# `alphabet` (characters outside it map to '_'), the target a one-hot vector over
# `langs`. The shuffle is not seeded, so the split differs between runs.
dataset = [(onehotbatch(s, alphabet, '_'), onehot(l, langs))
           for l in langs for s in corpora[l]] |> shuffle

# Fail clearly instead of with a BoundsError when there is too little data.
length(dataset) > ntest ||
  throw(ArgumentError("need more than $ntest examples, got $(length(dataset))"))

train, test = dataset[1:end-ntest], dataset[end-ntest+1:end];

In [None]:
# Width of the per-character dense layer and of the LSTM hidden state.
N = 15

# Per-character network: a dense σ layer on each one-hot character, followed by an
# LSTM whose state carries information from one character to the next.
scanner = Chain(Dense(length(alphabet), N, σ), LSTM(N, N))
# Maps the final LSTM output to one score per language in `langs`.
encoder = Dense(N, length(langs))

"""
    model(x)

Feed the characters of the one-hot matrix `x` through `scanner` one at a time.
Pass the output after the last character to `encoder` and return the softmax over
`langs`.

The LSTM state is reset before returning, even if the forward pass throws, so one
sentence never leaks state into the next. `x` must contain at least one character
(`[end]` fails on an empty sequence).
"""
function model(x)
  try
    # Broadcasting `scanner` over `x.data` (presumably the per-character one-hot
    # columns) steps the LSTM through the sentence; only the last output is kept.
    state = scanner.(x.data)[end]
    return softmax(encoder(state))
  finally
    reset!(scanner)
  end
end

# Cross-entropy between the predicted distribution and the one-hot target `y`.
loss(x, y) = crossentropy(model(x), y)

In [None]:
# Average loss over the held-out `test` pairs.
testloss() = mean(loss(x, y) for (x, y) in test)
# ADAM optimiser over the parameters of both `scanner` and `encoder`.
opt = ADAM(params(scanner, encoder))
# Training callback: print the current test loss.
evalcb = () -> @show(testloss())

In [None]:
# One pass of training over `train`. The progress callback is rate-limited by
# `throttle(evalcb, 10)` (presumably at most once every 10 seconds).
Flux.train!(loss, train, opt; cb = throttle(evalcb, 10))

In [None]:
using Interact, Plots

In [None]:
"""
    predict(s)

Return the model's probability for each language in `langs` for the string `s`.

`s` gets the same normalisation as the training sentences: it is case-folded,
stripped of diacritics and stripped of surrounding whitespace. If nothing is left
after that, a uniform distribution is returned instead of running the model, which
cannot handle an empty sequence.
"""
function predict(s)
  text = strip(normalize_string(s, casefold=true, stripmark=true))
  # Check emptiness *after* normalising. A lone combining mark or a run of spaces
  # is non-empty input, but it becomes empty once marks and whitespace are stripped.
  isempty(text) && return softmax(ones(length(langs)))
  # `.data` presumably unwraps the tracked model output into a plain array.
  return model(onehotbatch(text, alphabet, '_')).data
end

In [None]:
# Interactive widget: a text box (initialised with an Italian sentence) and a bar
# chart of the predicted probability for each language, redrawn as the text changes.
@manipulate for s = "c'é una bella filosofia"
    langnames = String.(langs)
    probs = predict(s)
    bar(langnames, probs; label = ["Probability"], ylims = (0, 1))
end