In [1]:
using BSON: @load
using Random: randperm
using LinearAlgebra: dot
using Printf: @printf
# Load the vocabulary `tokens` (token id `j` ↔ `tokens[j]`) and the train/test splits.
# NOTE(review): each sample appears to be `[title_ids, description_ids]`, inferred from
# how `recall_at_k` uses `data[i][1]` / `data[i][2]` — confirm against the data source.
@load "data.bson" tokens train_data test_data

In [2]:
# Concrete dataset type: a vector of samples, each sample a vector of token-id vectors
# (index 1 = title, index 2 = description, as used by `recall_at_k`).
# Declared `const` because it is a module-level constant: the `train!` signatures capture
# it at definition time, so rebinding it later would silently desynchronize them.
const T = Vector{Vector{Vector{Int}}}
train_data = convert(T, train_data)
test_data = convert(T, test_data);

In [3]:
"""
    Embedding

Token embedding table exposed as an `AbstractMatrix{Float32}`: column `j` is the vector
for token id `j`. It also carries the accumulator of squared gradients that the
AdaGrad-style `update!` reads and writes.
"""
struct Embedding <: AbstractMatrix{Float32}
    data::Matrix{Float32}  # embedding weights, one column per token id
    ∑Δ²::Matrix{Float32}   # running sum of squared gradients (same size as `data` when built by the constructors)
end

In [4]:
# Small constant that seeds the squared-gradient accumulator and guards the AdaGrad
# division in `update!`. Declared `const` because `update!` reads it inside a hot
# broadcast; an untyped non-const global there forces dynamic dispatch on every call.
const ϵ = 1.0f-8

"""
    Embedding(A::Matrix)

Wrap the weight matrix `A` in an `Embedding`, with every entry of the squared-gradient
accumulator initialised to `ϵ`.
"""
Embedding(A::Matrix) = Embedding(A, fill(ϵ, size(A)))

"""
    Embedding(emb_size, vocab_size; range = 0.001f0)

Create an `emb_size × vocab_size` embedding whose weights are standard-normal samples
scaled by `range` (i.e. standard deviation `range`).
"""
Embedding(emb_size, vocab_size; range::Float32 = 0.001f0) =
    Embedding(randn(Float32, emb_size, vocab_size) .* range)

Embedding

In [5]:
# Minimal AbstractArray interface: every query is forwarded to the wrapped `data` matrix.
function Base.size(E::Embedding)
    return size(E.data)
end
function Base.getindex(E::Embedding, i::Int, j::Int)
    return E.data[i, j]
end
function Base.setindex!(E::Embedding, x, i::Int, j::Int)
    return setindex!(E.data, x, i, j)
end

In [6]:
# Pool the embeddings of the token ids in `idx` by summing their columns.
# The result is a `size(m, 1) × 1` matrix (an empty `idx` yields a column of zeros).
function (m::Embedding)(idx::Vector)
    columns = @inbounds view(m, :, idx)
    return sum(columns; dims = 2)
end

In [7]:
"""
    update!(emb, idx, Δ, η)

Apply one AdaGrad step to the columns `idx` of `emb`. First add `Δ.^2` to the matching
columns of `emb.∑Δ²`, then subtract `Δ * η / (√∑Δ² + ϵ)` from those weight columns.
In this file `Δ` is an `n × 1` gradient, so it is broadcast across every selected
column. Returns `nothing`.
"""
function update!(emb::Embedding, idx::Vector{Int}, Δ::Matrix{Float32}, η::Float32)
    params = @inbounds view(emb, :, idx)
    accum = @inbounds view(emb.∑Δ², :, idx)
    # Accumulate first so the step below already uses the updated history.
    accum .+= abs2.(Δ)
    params .-= Δ .* η ./ (sqrt.(accum) .+ ϵ)
    return nothing
end

update! (generic function with 1 method)

In [8]:
"""
    backward_hinge(u, v, v̂, γ = 1.0f0)

Gradients of the margin loss `max(0, γ - u⋅v + u⋅v̂)` with respect to `(u, v, v̂)`.
Returns `nothing` when the hinge is inactive, i.e. when the margin term is not positive.
"""
function backward_hinge(
    u::Matrix{Float32}, v::Matrix{Float32}, v̂::Matrix{Float32}, γ = 1.0f0,
)
    margin = γ - dot(u, v) + dot(u, v̂)
    # Inactive hinge: the positive already beats the negative by at least γ.
    margin > 0 || return nothing
    ∂u = v̂ .- v
    return (∂u, -u, u)
end

backward_hinge (generic function with 2 methods)

In [9]:
"""
    top_k(scores, k)

Return the indices of the `k` largest entries of `scores`, ordered from largest to
smallest, as a `Vector{Int}`. If `k` exceeds `length(scores)`, all indices are returned.
If `k ≤ 0`, an empty vector is returned. The relative order of tied scores is
unspecified.
"""
function top_k(scores, k)
    # Clamp so an oversized `k` returns everything instead of throwing a BoundsError.
    n = min(k, length(scores))
    n <= 0 && return Int[]
    # `partialsortperm` only orders the first `n` positions; `collect` materialises the view.
    return collect(partialsortperm(scores, 1:n; rev = true))
end

top_k (generic function with 1 method)

In [10]:
"""
    recall_at_k(emb, data, k = 10)

Fraction of samples `i` in `data` whose own description `i` is among the `k`
descriptions with the highest dot product (between pooled `emb` embeddings) against
title `i`. Returns a `Float64`, which is `NaN` when `data` is empty.
"""
function recall_at_k(emb, data, k = 10)
    n = length(data)
    # Column j holds the pooled embedding of the j-th description.
    descriptions = Matrix{Float32}(undef, size(emb, 1), n)
    for (j, pair) in enumerate(data)
        descriptions[:, j] = emb(pair[2])
    end
    hits = 0
    for (j, pair) in enumerate(data)
        scores = vec(emb(pair[1])' * descriptions)
        if j in top_k(scores, k)
            hits += 1
        end
    end
    return hits / n
end

recall_at_k (generic function with 2 methods)

In [11]:
"""
    train!(emb, idx::NTuple{3, Vector{Int}}, η)

Run a single hinge-loss step on a triple `(u, v, v̂)` of token-id vectors. Each member is
pooled with `emb` and the gradients come from `backward_hinge`. If the hinge is active,
`update!` is applied to each member's columns with learning rate `η`. Returns `nothing`.
"""
function train!(emb::Embedding, idx::NTuple{3, Vector{Int}}, η::Float32)
    pooled = map(emb, idx)
    grads = backward_hinge(pooled...)
    grads === nothing && return nothing
    for (cols, Δ) in zip(idx, grads)
        update!(emb, cols, Δ, η)
    end
    return nothing
end

train! (generic function with 1 method)

In [12]:
"""
    train!(emb, data::T, η)

Run one training epoch over `data` in a random order.

For each sample `[title, description]`, the tokens the two sides share are removed from
one randomly chosen side. This is done on a fresh vector; `data` itself is never
modified. The description of the previous sample in the shuffled order serves as the
negative example. With a single sample, that negative is the sample itself.
Returns `nothing`.
"""
function train!(emb::Embedding, data::T, η::Float32)
    order = randperm(length(data))
    # circshift by 1 pairs each sample with its predecessor in the shuffled order.
    negatives = circshift(order, 1)
    for (i, j) in zip(order, negatives)
        title, descr = data[i]
        # Non-mutating `setdiff`: the former `setdiff!` rewrote the caller's training set,
        # freezing the first epoch's random choice and dropping duplicate tokens for good.
        # Like `setdiff!`, `setdiff` also de-duplicates the side it filters.
        if rand(Bool)
            descr = setdiff(descr, title)
        else
            title = setdiff(title, descr)
        end
        train!(emb, (title, descr, data[j][2]), η)
    end
    return nothing
end

train! (generic function with 2 methods)

In [13]:
"""
    train!(emb, train_data::T, test_data::T, n_epochs, η)

Train `emb` on `train_data` for `n_epochs` epochs. After each epoch, print the wall time
spent training and `recall_at_k` on `test_data`. Returns `nothing`.
"""
function train!(emb::Embedding, train_data::T, test_data::T, n_epochs::Int, η::Float32)
    for epoch = 1:n_epochs
        seconds = @elapsed train!(emb, train_data, η)
        score = recall_at_k(emb, test_data)
        @printf("Epoch %2i (%1.1fs): recall = %1.2f\n", epoch, seconds, score)
    end
    return nothing
end

train! (generic function with 3 methods)

In [14]:
# 256-dimensional embeddings with one column per vocabulary token. The global `emb` is
# also the table that `knn` queries. Train for 20 epochs with learning rate η = 1.
emb = Embedding(256, length(tokens))
train!(emb, train_data, test_data, 20, 1f0)

Epoch  1 (1.3s): recall = 0.08
Epoch  2 (0.4s): recall = 0.08


Epoch  3 (0.3s): recall = 0.08
Epoch  4 (0.3s): recall = 0.06


Epoch  5 (0.2s): recall = 0.06
Epoch  6 (0.2s): recall = 0.07


Epoch  7 (0.2s): recall = 0.07
Epoch  8 (0.2s): recall = 0.07


Epoch  9 (0.2s): recall = 0.07
Epoch 10 (0.1s): recall = 0.07


Epoch 11 (0.2s): recall = 0.06
Epoch 12 (0.1s): recall = 0.07


Epoch 13 (0.1s): recall = 0.05
Epoch 14 (0.1s): recall = 0.05


Epoch 15 (0.1s): recall = 0.07
Epoch 16 (0.1s): recall = 0.07


Epoch 17 (0.2s): recall = 0.07
Epoch 18 (0.1s): recall = 0.07


Epoch 19 (0.1s): recall = 0.08
Epoch 20 (0.1s): recall = 0.09


In [15]:
"""
    knn(query, k = 10; embedding = emb, vocab = tokens)

Return the `k` entries of `vocab` whose embedding columns have the largest dot product
with the pooled (summed) embedding of the words in `query`.

`embedding` and `vocab` default to the notebook globals `emb` and `tokens`, looked up
at call time.

# Throws
- `ArgumentError` if `query` is empty or contains a word not in `vocab`.
"""
function knn(query, k = 10; embedding = emb, vocab = tokens)
    isempty(query) && throw(ArgumentError("query must contain at least one word"))
    positions = indexin(query, vocab)
    # `indexin` yields `nothing` for unknown words; without this check they would fail
    # deep inside the embedding lookup with an unhelpful MethodError.
    unknown = [w for (w, p) in zip(query, positions) if isnothing(p)]
    isempty(unknown) ||
        throw(ArgumentError("words not in vocabulary: $(join(unknown, ", "))"))
    idx = Int.(positions)
    scores = vec(embedding(idx)' * embedding)
    return vocab[top_k(scores, k)]
end

knn (generic function with 2 methods)

In [20]:
# The ten (default k) vocabulary tokens nearest to "ситуация" ("situation").
knn(["ситуация"])

10-element Vector{Any}:
 "гораздо"
 "неё"
 "сказал"
 "отправлять"
 "я"
 "мы"
 "рассказал"
 "организовать"
 "участие"
 "сейчас"