# Implementation

In [1]:
# Activate the project in the current directory and add every dependency in a
# single `Pkg.add` call, so the package resolver runs once instead of once per package.
import Pkg
Pkg.activate(".")
Pkg.add(["Transformers", "Pickle", "Flux"]);

[32m[1m  Activating[22m[39m project at `/content`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `/content/Project.toml`
[32m[1m  No Changes[22m[39m to `/content/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `/content/Project.toml`
[32m[1m  No Changes[22m[39m to `/content/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `/content/Project.toml`
[32m[1m  No Changes[22m[39m to `/content/Manifest.toml`


In [2]:
using Transformers
using Transformers.TextEncoders
using Transformers.HuggingFace
using Flux
using Pickle

In [3]:
"""
    copy_with_modification(original, field_to_change, new_value)

Build a new value of the same type as `original` by calling its positional
constructor with all of `original`'s field values. The one exception is
`field_to_change`, which receives `new_value` instead.

If `field_to_change` is not a field name of the type, every field is copied
unchanged.
"""
function copy_with_modification(original::T, field_to_change, new_value) where {T}
    # Walk the fields in declaration order so they line up with the constructor's arguments.
    fieldvalues = map(fieldnames(T)) do name
        name == field_to_change ? new_value : getfield(original, name)
    end
    return T(fieldvalues...)
end

copy_with_modification (generic function with 1 method)

In [4]:
# Special-token ids used when building ColBERT inputs.
# NOTE(review): these look like 1-based positions in the BERT WordPiece vocabulary
# (the 0-based Hugging Face id plus one) — confirm against `textencoder`'s vocab.

# Standard BERT control tokens.
const cls_token = 102
const sep_token = 103
const mask_token = 104

# ColBERT marker tokens, written over token position 2 by `encode_query` and
# `encode_document` respectively.
const query_marker_token = 2
const document_marker_token = 3

# Padding token (kept last so this cell still evaluates to 1).
const pad_token = 1

1

In [5]:
"""
    ColBERT(bert, linear)

ColBERT scorer built from a BERT encoder `bert` and a projection layer `linear`.
`linear` is applied to the encoder's `hidden_state`.

The type parameters make both fields concretely typed. `Dense` and
`HGFBertModel` are parametric types, so a field annotated with the bare name is
abstract, and calls through it cannot be inferred. The default constructor still
accepts `ColBERT(bert_model, linear_layer)` and infers the parameters.
"""
struct ColBERT{B<:Transformers.HuggingFace.HGFBertModel, L<:Dense}
    bert::B
    linear::L
end

"""
    l2_normalization(x::AbstractArray; dims = 1, eps = 1e-12)

Return `x` with every slice along `dims` scaled to unit Euclidean (L2) norm.

Norms smaller than `eps` are clamped to `eps` before dividing, as PyTorch's
`F.normalize` does. An all-zero slice therefore stays zero instead of becoming
`NaN`. For non-degenerate slices the result is identical to plain `x ./ norm`.
The element type of the norms is preserved (e.g. `Float32` in, `Float32` out).

# Arguments
- `x`: any array (views and GPU arrays included).
- `dims`: dimension(s) the norm is taken over. The default `1` normalizes columns
  / embedding vectors.
- `eps`: lower bound for the norm.
"""
function l2_normalization(x::AbstractArray; dims = 1, eps = 1e-12)
    norms = sqrt.(sum(abs2, x; dims = dims))
    # Convert `eps` to the norms' eltype so Float32 inputs are not promoted to Float64.
    norms = max.(norms, convert(eltype(norms), eps))
    return x ./ norms
end

"""
    _encode_with_marker(text, marker; encoder = textencoder)

Encode `". " * text` with `encoder`, then overwrite token position 2 with `marker`.

The leading `". "` reserves a placeholder token that the marker replaces.
NOTE(review): this assumes the encoder places exactly one special token
(presumably `[CLS]`) before the text, so position 2 is the placeholder — confirm
against the tokenizer.

Returns the named tuple `(; token = tokens)` holding the modified token array.
"""
function _encode_with_marker(text::AbstractString, marker; encoder = textencoder)
    tokens = encode(encoder, ". " * text).token
    # Overwrite the placeholder in place. The token array was produced by `encode`
    # in this very call, so nothing else observes the change (assuming `encode`
    # does not return a shared buffer — TODO confirm).
    tokens.onehots[2] = marker
    return (; token = tokens)
end

"""
    encode_query(query; encoder = textencoder)

Tokenize `query` for ColBERT and tag it with `query_marker_token` at token position 2.
Returns `(; token = tokens)`.
"""
function encode_query(query::AbstractString; encoder = textencoder)
    return _encode_with_marker(query, query_marker_token; encoder = encoder)
end

"""
    encode_document(doc; encoder = textencoder)

Tokenize `doc` for ColBERT and tag it with `document_marker_token` at token position 2.
Returns `(; token = tokens)`.
"""
function encode_document(doc::AbstractString; encoder = textencoder)
    return _encode_with_marker(doc, document_marker_token; encoder = encoder)
end

"""
    (m::ColBERT)(query_ids, doc_ids)

Return the late-interaction (MaxSim) score of a query against a document.

Steps:
1. Both inputs are run through `m.bert`.
2. Each token's `hidden_state` is projected by `m.linear` and L2-normalized
   along dimension 1.
3. For every query token, the similarity to its best-matching document token
   is taken.
4. Those maxima are summed over the query tokens.

The per-token similarity is `2 q⋅d`. For unit vectors this equals the
`2 - ‖q - d‖²` form used previously, i.e. twice the cosine MaxSim. The scale is
kept so existing scores are unchanged.

Only the score of the first batch element is returned.
"""
function (m::ColBERT)(query_ids, doc_ids)
    # (dim, query_len, batch) and (dim, doc_len, batch)
    Q = l2_normalization(m.linear(m.bert(query_ids).hidden_state))
    D = l2_normalization(m.linear(m.bert(doc_ids).hidden_state))

    # sim[j, i, b] = 2 * (d_j ⋅ q_i), shape (doc_len, query_len, batch).
    # A batched matmul avoids materializing a (dim, doc_len, query_len, batch) tensor.
    sim = 2 .* Flux.batched_mul(Flux.batched_transpose(D), Q)

    # Best document token for each query token: (1, query_len, batch).
    best_per_query_token = maximum(sim; dims = 1)

    return sum(best_per_query_token; dims = 2)[1]
end

In [6]:
# Load the tokenizer and BERT backbone for ColBERTv2 from the Hugging Face hub.
textencoder, bert_model = hgf"colbert-ir/colbertv2.0"

# The projection layer is not part of the BERT model, so its weight is read
# directly from the PyTorch checkpoint.
# NOTE(review): assumes `pytorch_model.bin` from colbert-ir/colbertv2.0 is in the
# working directory.
colbert_parameters = Pickle.Torch.THload("pytorch_model.bin")

# Only `linear.weight` is loaded, so build the layer without a bias instead of
# letting Flux add an all-zero bias vector.
linear_layer = Dense(Matrix(colbert_parameters["linear.weight"]), false)

model = ColBERT(bert_model, linear_layer)

ColBERT(HGFBertModel(Chain(CompositeEmbedding(token = Embed(768, 30522), position = ApplyEmbed(.+, FixedLenPositionEmbed(768, 512)), segment = ApplyEmbed(.+, Embed(768, 2), Transformers.HuggingFace.bert_ones_like)), DropoutLayer<nothing>(LayerNorm(768, ϵ = 1.0e-12))), Transformer<12>(PostNormTransformerBlock(DropoutLayer<nothing>(SelfAttention(MultiheadQKVAttenOp(head = 12, p = nothing), Fork<3>(Dense(W = (768, 768), b = true)), Dense(W = (768, 768), b = true))), LayerNorm(768, ϵ = 1.0e-12), DropoutLayer<nothing>(Chain(Dense(σ = NNlib.gelu, W = (768, 3072), b = true), Dense(W = (3072, 768), b = true))), LayerNorm(768, ϵ = 1.0e-12))), Branch{(:pooled,) = (:hidden_state,)}(BertPooler(Dense(σ = NNlib.tanh_fast, W = (768, 768), b = true)))), Dense(768 => 128))

# Examples

In [7]:
query = "what is Julia language"
documents = [
    "Julia is a greate language",
    "I don't know",
    "Julia is my sister; she helpes with cleaning my room",
    "Harry Potter is a series of seven fantasy novels written by J. K. Rowling",
]

Q = encode_query(query)
D = map(encode_document, documents)

# Print the raw ColBERT score of the query against each document, in input order.
foreach(d -> println(model(Q, d)), D)

12.528865
2.7685688
7.5746684
3.380197


In [8]:
query = "What is Julia language"
documents = [
    "Julia is a greate language",
    "I don't know",
    "Julia is my sister; she helpes with cleaning my room",
    "Harry Potter is a series of seven fantasy novels written by J. K. Rowling",
]

encoded_query = encode_query(query)
encoded_documents = [encode_document(doc) for doc in documents]

# Score every document. The buffer is Float64, so each model score is widened
# when stored.
scores = zeros(length(documents))
for (i, encoded_doc) in enumerate(encoded_documents)
    scores[i] = model(encoded_query, encoded_doc)
end

# (document index, score) pairs, sorted best score first.
document_pairs = collect(zip(eachindex(scores), scores))
document_order = sort(document_pairs, by = last, rev = true)
ordered_documents = [(documents[i], score) for (i, score) in document_order]

display(ordered_documents)

4-element Vector{Tuple{String, Float64}}:
 ("Julia is a greate language", 12.528864860534668)
 ("Julia is my sister; she helpes with cleaning my room", 7.5746684074401855)
 ("Harry Potter is a series of seven fantasy novels written by J. K. Rowling", 3.380197048187256)
 ("I don't know", 2.768568754196167)