In [1]:
using Flux, CuArrays
using Flux: onehot, onehotbatch, crossentropy, reset!, throttle, batch, batchseq

In [2]:
corpora = Dict()

for file in readdir("corpus")
  lang = Symbol(match(r"(.*)\.txt", file).captures[1])
  corpus = split(String(read("corpus/$file")), ".")
  corpus = strip.(normalize_string.(corpus, casefold=true, stripmark=true))
  corpus = filter(!isempty, corpus)
  corpora[lang] = corpus
end

corpora

Dict{Any,Any} with 5 entries:
  :en => String["wikipedia (/ˌwɪkɪˈpiːdiə/ ( listen)wik-i-pee-dee-ə or /ˌwɪkiˈp…
  :it => String["wikipedia (pronuncia: vedi sotto) e un'enciclopediaonline a co…
  :fr => String["wikipediaecouter est un projet d'encyclopedie universelle, mul…
  :es => String["wikipedia es una enciclopedialibre,[nota 2]\u200bpoliglota y e…
  :da => String["wikipedia er en encyklopædi med abent indhold, skrevet i samar…

In [3]:
langs = collect(keys(corpora))
alphabet = ['a':'z'; '0':'9'; ' '; '\n'; '_'; '\0'];

In [4]:
# See which chars will be represented as "unknown"
unique(filter(x -> x ∉ alphabet, join(vcat(values(corpora)...))))

152-element Array{Char,1}:
 '('
 '/'
 'ˌ'
 'ɪ'
 'ˈ'
 'ː'
 'ə'
 ' '
 ')'
 '-'
 '['
 ']'
 ','
 ⋮  
 'ব'
 'ল'
 'দ'
 'শ'
 'চ'
 'ট'
 'ম'
 'ঢ'
 'ক'
 'খ'
 'হ'
 'স'

In [5]:
dataset = [(onehotbatch(s, alphabet, '_').data, onehot(l, langs))
           for l in langs for s in corpora[l]]
dataset = sort(dataset, by = x -> length(x[1]))
dataset = [(batchseq(map(x->x[1],dataset[i]), onehot('\0', alphabet)), batch(map(x->x[2], dataset[i])))
           for i in Iterators.partition(1:length(dataset), 50)] |> shuffle
train, test = dataset[1:end-5], dataset[end-5+1:end];

In [6]:
N = 15

scanner = cu(Chain(Dense(length(alphabet), N, σ), LSTM(N, N)))
encoder = cu(Dense(N, length(langs)))

function model(x)
  state = scanner.(cu.(collect.(x)))[end]
  reset!(scanner)
  softmax(encoder(state))
end

loss(x, y) = crossentropy(model(x), cu(collect(y)))

loss (generic function with 1 method)

In [7]:
testloss() = mean(loss(t...) for t in test)
opt = ADAM(params(scanner, encoder))
evalcb = () -> @show testloss()

(::#21) (generic function with 1 method)

In [8]:
@time for i = 1:10 Flux.train!(loss, train, opt, cb = throttle(evalcb, 10)) end

testloss() = 1.6616275f0 (tracked)
testloss() = 1.6077011f0 (tracked)
testloss() = 1.5930064f0 (tracked)
testloss() = 1.6017551f0 (tracked)
testloss() = 1.6097479f0 (tracked)
testloss() = 1.620472f0 (tracked)
testloss() = 1.597518f0 (tracked)
testloss() = 1.6011655f0 (tracked)
testloss() = 1.5762886f0 (tracked)
testloss() = 1.585687f0 (tracked)
testloss() = 1.5743716f0 (tracked)
testloss() = 1.5689309f0 (tracked)
testloss() = 1.6136243f0 (tracked)
testloss() = 1.5515162f0 (tracked)
testloss() = 1.5636933f0 (tracked)
testloss() = 1.5694954f0 (tracked)
testloss() = 1.4386997f0 (tracked)
testloss() = 1.6454777f0 (tracked)
testloss() = 1.5338187f0 (tracked)
testloss() = 1.5023234f0 (tracked)
testloss() = 1.4564459f0 (tracked)
testloss() = 1.3861965f0 (tracked)
testloss() = 1.5146482f0 (tracked)
testloss() = 1.4520594f0 (tracked)
testloss() = 1.5770667f0 (tracked)
testloss() = 1.5210006f0 (tracked)
testloss() = 1.5022162f0 (tracked)
testloss() = 1.3403647f0 (tracked)
testloss() = 1.347895f0

In [9]:
# open(io -> serialize(io, (langs, alphabet, scanner, encoder)), "model-1.03.jls", "w")

In [10]:
(langs, alphabet, scanner, encoder) = open(deserialize, "model-1.03.jls")

(Any[:en, :it, :fr, :es, :da], ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'  …  '4', '5', '6', '7', '8', '9', ' ', '\n', '_', '\0'], Chain(Dense(40, 15, NNlib.σ), Recur(LSTMCell(15, 60))), Dense(15, 5))

In [11]:
using Interact, Plots

[91mArgumentError: Module Hiccup not found in current path.
Run `Pkg.add("Hiccup")` to install the Hiccup package.[39m


In [12]:
predict(s) =
    isempty(s) ?
        softmax(ones(length(langs))) :
        model(onehotbatch(normalize_string(s, casefold=true, stripmark=true), alphabet, '_').data).data

predict (generic function with 1 method)

In [14]:
@manipulate for s = "c'é una bella filosofia"
    bar(String.(langs), predict(s),
        label=["Probability"], ylims=(0,1))
end