In [1]:
using TextAnalysis, CorpusLoaders, MultiResolutionIterators, LinearAlgebra

┌ Info: Recompiling stale cache file /home/ayushk4/.julia/compiled/v1.0/TextAnalysis/5Mwet.ji for TextAnalysis [a2db99b7-8b79-58f8-94bf-bbc811eef33d]
└ @ Base loading.jl:1190
│ - If you have TextAnalysis checked out for development and have
│   added Libdl as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with TextAnalysis


In [3]:
dataset = CorpusLoaders.load(WikiGold())
dataset = flatten_levels(dataset, lvls(WikiGold, :document)) |> full_consolidate

X = [CorpusLoaders.word.(sent) for sent in dataset]
Y = [TextAnalysis.remove_ner_label_prefix.(CorpusLoaders.named_entity.(sent)) for sent in dataset]
@assert length.(X) == length.(Y)

In [4]:
function try_outs(ner_m, x_in, y_in, eval_func)
    unique_labels = unique(ner.model.labels)
    num_labels = length(unique_labels)
    confusion_matrix = zeros(Int, (num_labels, num_labels))

    for (x_seq, y_seq) in zip(x_in, y_in)

        preds = eval_func(ner_m, x_seq)
        length(preds) != length(y_seq) && continue

        for (pred, logit) in zip(preds, y_seq)
            (logit == "MISC" || pred == "INVALID") && continue
            confusion_matrix[findfirst(x -> x==pred, unique_labels), findfirst(x -> x==logit, unique_labels)] += 1
        end
    end

    s1 = sum(confusion_matrix, dims=2)
    s2 = sum(confusion_matrix, dims=1)'
    dg = diag(confusion_matrix)
    s1 = [s1[1:2]..., s1[4:5]...]
    s2 = [s2[1:2]..., s2[4:5]...]
    dg = [dg[1:2]..., dg[4:5]...]

    unique_labels = unique(ner.model.labels)
    deleteat!(unique_labels, findfirst(x -> x=="MISC", unique_labels))
    # Don't count MISC
    
    f1s = []

    for (p, r, d, tag) in zip(s1, s2, dg, unique_labels)
        println("For tag `$tag`")
        prec = d/p
        recall = d/r
        f1 = (2 * prec * recall) /(prec + recall)
        println("The precision is $prec")
        println("The recall is $recall")
        println("f1 is $f1")
        println()
        push!(f1s, f1)
    end

    a = sum(dg ./ s1) / length(unique_labels)
    b = sum(dg ./ s2) / length(unique_labels)
    println("Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is ", (2 * a * b)/ (a + b))
    println("Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is ", sum(f1s)/ length(f1s))
end

try_outs (generic function with 1 method)

In [7]:
ner = NERTagger()

function eval_ner_tagger(ner_m, x_seq) 
    ner_m(x_seq)
end

eval_ner_tagger (generic function with 1 method)

In [6]:
try_outs(ner, X, Y, eval_ner_tagger)

For tag `ORG`
The precision is 0.7349726775956285
The recall is 0.4121552604698672
f1 is 0.5281413612565445

For tag `O`
The precision is 0.9643743088556144
The recall is 0.990514489194499
f1 is 0.9772696297418036

For tag `PER`
The precision is 0.8627797408716137
The recall is 0.8965728274173806
f1 is 0.8793517406962785

For tag `LOC`
The precision is 0.8078495502861816
The recall is 0.68279198341396
f1 is 0.7400749063670413

Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7910397182885844
Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7812094095154171


In [8]:
function eval_spacy_tagger(ner_m, x_seq)
    preds = String[]
    ents = ner_m(join(x_seq, " ")).ents

    idx = 1
    i = 1
    while i <= length(x_seq)
        if idx <= length(ents) && x_seq[i] == tokenize(ents[idx].text)[1]
            l = length(tokenize(ents[idx].text))

            for k in 1:l
                pred = ents[idx].label_
                if (pred == "PERSON")
                    push!(preds, "PER")
                elseif ( pred == "LOC")
                    push!(pred == "GPE" || preds, "LOC")
                elseif (pred == "ORG")
                    push!(preds, "ORG")
                else
                    push!(preds, "INVALID")
                end
            end
            i = i + l - 1
            idx += 1
        else
            push!(preds, "O")
        end
        i += 1
    end

    return preds
end


eval_spacy_tagger (generic function with 1 method)

In [9]:
using PyCall, WordTokenizers
spacy = pyimport("spacy")
nlp = spacy.load("en_core_web_sm")

PyObject <spacy.lang.en.English object at 0x7ff965d339b0>

In [10]:
try_outs(nlp, X, Y, eval_spacy_tagger)

For tag `ORG`
The precision is 0.5066469719350074
The recall is 0.5968677494199536
f1 is 0.5480692410119841

For tag `O`
The precision is 0.9679771928661407
The recall is 0.9737877676248916
f1 is 0.970873786407767

For tag `PER`
The precision is 0.7230283911671924
The recall is 0.7514754098360655
f1 is 0.7369774919614149

For tag `LOC`
The precision is 0.5833333333333334
The recall is 0.12944523470839261
f1 is 0.21187427240977885

Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is 0.6514780566020529
Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is 0.6169486979477363


In [13]:
nltk = pyimport("nltk")
nltk_chunker = nltk.load(nltk.chunk._MULTICLASS_NE_CHUNKER)
nltk_ner(x) = nltk_chunker._tagger.tag(nltk.pos_tag((x)))

nltk_ner (generic function with 1 method)

In [14]:
function eval_nltk_tagger(ner_m, x_seq) 
    obtain_ner(x) = (x[2]) == "O" ? "O" : (x[2])[3:end]
    preds = obtain_ner.(ner_m(x_seq))

    for i in eachindex(preds)
        preds[i] == "O" && continue

        if preds[i] == "PERSON"
            preds[i] = "PER"
        elseif preds[i] == "ORGANIZATION"
            preds[i] = "ORG"
        elseif preds[i] ∈ ("LOCATION", "GPE")
            preds[i] = "LOC"
        else
            preds[i] = "INVALID"
        end
    end
    return preds
end


eval_nltk_tagger (generic function with 1 method)

In [15]:
try_outs(nltk_ner, X, Y, eval_nltk_tagger)

For tag `ORG`
The precision is 0.6182669789227166
The recall is 0.41015018125323666
f1 is 0.4931506849315068

For tag `O`
The precision is 0.9742044869659996
The recall is 0.9878730197715829
f1 is 0.9809911434276918

For tag `PER`
The precision is 0.5979284369114878
The recall is 0.7772337821297429
f1 is 0.6758914316125598

For tag `LOC`
The precision is 0.6344950848972297
The recall is 0.501412429378531
f1 is 0.5601577909270218

Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is 0.6871963551740787
Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is 0.6775477627246951
