In [1]:
using TextAnalysis, CorpusLoaders, MultiResolutionIterators, LinearAlgebra, Statistics

┌ Info: Recompiling stale cache file /home/ayushk4/.julia/compiled/v1.0/TextAnalysis/5Mwet.ji for TextAnalysis [a2db99b7-8b79-58f8-94bf-bbc811eef33d]
└ @ Base loading.jl:1190
│ - If you have TextAnalysis checked out for development and have
│   added Libdl as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with TextAnalysis


In [2]:
# Load the GMB corpus through CorpusLoaders, flatten away its `:document` level and
# consolidate the nested iterators into plain arrays.
dataset = full_consolidate(
    flatten_levels(collect(CorpusLoaders.load(GMB())), lvls(GMB, :document)),
)

# `X[k]` holds the `word` of every item in `dataset[k]`; `Y[k]` holds the matching
# `named_entity` of every item, so the two are parallel.
X = map(sentence -> word.(sentence), dataset)
Y = map(sentence -> CorpusLoaders.named_entity.(sentence), dataset)

9418-element Array{Array{String,1},1}:
 ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "Location", "O", "O", "O", "Person", "Person", "Person", "O"]                             
 ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O"  …  "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]                                                
 ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O"  …  "O", "O", "O", "O", "O", "O", "O", "Location", "Timex", "O"]                                     
 ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O"  …  "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]                                                
 ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O"  …  "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]                                                
 ["O", "O", "O", "O", "O", "O", "Person", "Person", "Person", "O"  …  "O", "O", "O", "O", "O", "O", "Location", "Location", "O", "O"]                   
 ["O", "O", "O", "O", "O", "O", "O", "O", "

In [3]:
# Build the tagger under evaluation with its default constructor. `try_outs` reads
# `ner.model.labels` from this global, so it must be defined before `try_outs` is called.
ner = NERTagger()

"""
    eval_ner_tagger(ner_m, x_seq)

Apply the callable tagger `ner_m` to the token sequence `x_seq` and return its result
unchanged. Exists so the tagger can be plugged into `try_outs` through the common
`eval_func(ner_m, x_seq)` calling convention.
"""
eval_ner_tagger(ner_m, x_seq) = ner_m(x_seq)


eval_ner_tagger (generic function with 1 method)

In [4]:
"""
    try_outs(ner_m, x_in, y_in, eval_func; labels = unique(ner.model.labels))

Score a NER tagger against gold annotations and print per-tag precision, recall and F1,
followed by two overall F1 figures.

# Arguments
- `ner_m`: tagger handle; it is passed unchanged to `eval_func`.
- `x_in`: collection of token sequences.
- `y_in`: collection of gold label sequences, parallel to `x_in`. Only the gold labels
  "O", "Location", "Person" and "Organization" are scored; any other gold label
  (including "MISC") is skipped.
- `eval_func`: called as `eval_func(ner_m, x_seq)`; expected to return one predicted tag
  per token. Positions predicted as "INVALID" are skipped.

# Keywords
- `labels`: tag inventory used to index the confusion matrix. Defaults to the labels of
  the global `ner` tagger. A "MISC" entry, wherever it appears, is left out of the
  printed report.

# Throws
- `ArgumentError` when a scored prediction, or the tag a gold label maps to, is not
  part of `labels`.

Returns `nothing`; all results are written with `println`.
"""
function try_outs(ner_m, x_in, y_in, eval_func; labels = unique(ner.model.labels))
    tag_names = unique(labels)
    # Constant-time tag -> matrix index lookup instead of a linear scan per token.
    tag_index = Dict(tag => i for (i, tag) in enumerate(tag_names))

    # Gold labels of the dataset mapped onto the tagger's tag names.
    gold_to_tag = Dict(
        "O" => "O",
        "Location" => "LOC",
        "Person" => "PER",
        "Organization" => "ORG",
    )

    # Rows are indexed by the predicted tag, columns by the gold tag.
    confusion_matrix = zeros(Int, length(tag_names), length(tag_names))

    for (x_seq, y_seq) in zip(x_in, y_in)
        preds = eval_func(ner_m, x_seq)

        # `zip` stops at the shorter sequence, so surplus predictions or gold labels
        # are ignored rather than reported.
        for (pred, gold) in zip(preds, y_seq)
            pred == "INVALID" && continue
            haskey(gold_to_tag, gold) || continue

            row = get(tag_index, pred, 0)
            col = get(tag_index, gold_to_tag[gold], 0)
            row == 0 && throw(ArgumentError("predicted tag \"$pred\" is not in labels"))
            col == 0 &&
                throw(ArgumentError("tag \"$(gold_to_tag[gold])\" is not in labels"))

            confusion_matrix[row, col] += 1
        end
    end

    # Drop MISC by name rather than by a hard-coded position, so the report stays
    # correct for any ordering or size of `labels`.
    keep = findall(tag -> tag != "MISC", tag_names)
    kept_tags = tag_names[keep]
    predicted_totals = vec(sum(confusion_matrix, dims=2))[keep]
    gold_totals = vec(sum(confusion_matrix, dims=1))[keep]
    correct = diag(confusion_matrix)[keep]

    f1s = Float64[]

    for (p, r, d, tag) in zip(predicted_totals, gold_totals, correct, kept_tags)
        println("For tag `$tag`")
        # A tag that was never predicted (or never occurs in the gold data) gives 0/0,
        # i.e. NaN, which then propagates into the overall figures.
        prec = d / p
        recall = d / r
        f1 = (2 * prec * recall) / (prec + recall)
        println("The precision is $prec")
        println("The recall is $recall")
        println("f1 is $f1")
        println()
        push!(f1s, f1)
    end

    # NOTE(review): the figure printed as "Micro" is the F1 of the per-tag *averaged*
    # precision and recall, not an F1 over pooled counts - confirm this is intended.
    a = sum(correct ./ predicted_totals) / length(kept_tags)
    b = sum(correct ./ gold_totals) / length(kept_tags)
    # NOTE(review): the dataset name in these messages is hard-coded and does not depend
    # on `x_in`/`y_in` - confirm it matches the data actually passed in.
    println("Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is ", (2 * a * b) / (a + b))
    println("Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is ", sum(f1s) / length(f1s))

    return nothing
end


try_outs (generic function with 1 method)

In [5]:
# Score the `ner` tagger on the token sequences `X` against the gold labels `Y`,
# obtaining predictions through `eval_ner_tagger`.
try_outs(ner, X, Y, eval_ner_tagger)

For tag `ORG`
The precision is 0.7026740799576383
The recall is 0.38519593613933234
f1 is 0.4976094497046968

For tag `O`
The precision is 0.9789615040286481
The recall is 0.974991755523799
f1 is 0.9769725972070936

For tag `PER`
The precision is 0.7854508196721312
The recall is 0.8264338076757223
f1 is 0.8054213069972682

For tag `LOC`
The precision is 0.8240566037735849
The recall is 0.7900508762012436
f1 is 0.8066955266955267

Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7815047090242612
Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7716747201511462


In [6]:
using PyCall, WordTokenizers
# Import the Python `spacy` module through PyCall and load its `en_core_web_sm` pipeline.
# `nlp` is the callable that is later handed to `eval_spacy_tagger` as `ner_m`.
spacy = pyimport("spacy")
nlp = spacy.load("en_core_web_sm")

"""
    eval_spacy_tagger(ner_m, x_seq)

Run the pipeline `ner_m` on the tokens of `x_seq` joined by single spaces and project the
entities it reports (`.ents`, each with `.text` and `.label_`) back onto the tokens.

Returns a `Vector{String}` with exactly one tag per element of `x_seq`:
- "PER" for entity label "PERSON",
- "LOC" for "GPE" or "LOC",
- "ORG" for "ORG",
- "INVALID" for any other entity label,
- "O" for tokens not covered by an aligned entity.

Entities are aligned in order: the pending entity is matched when its first token (as
produced by `tokenize`) equals the current element of `x_seq`, and it then covers as many
tokens as it tokenizes to. Only the first token is compared.

NOTE(review): the pending entity is only advanced on a match, so an entity whose first
token never equals a remaining input token also prevents all later entities from being
aligned. This mirrors the previous behaviour - confirm whether it is acceptable.
"""
function eval_spacy_tagger(ner_m, x_seq)
    ents = ner_m(join(x_seq, " ")).ents
    n_ents = length(ents)

    # Tokenize every entity once up front instead of re-tokenizing the pending entity
    # for every input token (and a second time on each match).
    ent_tokens = [tokenize(ents[k].text) for k in 1:n_ents]

    preds = String[]
    idx = 1  # entity waiting to be aligned
    i = 1    # current position in x_seq
    while i <= length(x_seq)
        # An entity that tokenizes to nothing can never align and would raise a
        # BoundsError on `[1]`; skip it so the following entities stay reachable.
        while idx <= n_ents && isempty(ent_tokens[idx])
            idx += 1
        end

        if idx <= n_ents && x_seq[i] == ent_tokens[idx][1]
            label = ents[idx].label_
            tag = if label == "PERSON"
                "PER"
            elseif label == "GPE" || label == "LOC"
                "LOC"
            elseif label == "ORG"
                "ORG"
            else
                "INVALID"
            end

            # Never tag past the end of the input, so that the result always has
            # exactly one entry per token of x_seq.
            span = min(length(ent_tokens[idx]), length(x_seq) - i + 1)
            append!(preds, fill(tag, span))
            i += span
            idx += 1
        else
            push!(preds, "O")
            i += 1
        end
    end

    return preds
end

eval_spacy_tagger (generic function with 1 method)

In [7]:
# Score the spaCy pipeline `nlp` on the same `X`/`Y` data, obtaining predictions through
# `eval_spacy_tagger`.
try_outs(nlp, X, Y, eval_spacy_tagger)

For tag `ORG`
The precision is 0.5882756256427837
The recall is 0.5323406235458352
f1 is 0.5589121407051543

For tag `O`
The precision is 0.9730422627751332
The recall is 0.9826380531721819
f1 is 0.9778166164999839

For tag `PER`
The precision is 0.7275661717236928
The recall is 0.7420983318700615
f1 is 0.7347604042160166

For tag `LOC`
The precision is 0.7454957960763379
The recall is 0.6581831035701662
f1 is 0.6991239048811013

Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7434068789865743
Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7426532665755641


In [8]:
# Import the Python `nltk` module through PyCall and load the chunker stored at
# `nltk.chunk._MULTICLASS_NE_CHUNKER`.
nltk = pyimport("nltk")
nltk_chunker = nltk.load(nltk.chunk._MULTICLASS_NE_CHUNKER)

# Tag `x` with `nltk.pos_tag`, then feed the result to the chunker's internal tagger.
function nltk_ner(x)
    pos_tagged = nltk.pos_tag(x)
    return nltk_chunker._tagger.tag(pos_tagged)
end

"""
    eval_nltk_tagger(ner_m, x_seq)

Call `ner_m(x_seq)` and convert the second field of every returned item into one of the
tags "O", "PER", "ORG", "LOC" or "INVALID".

A field equal to "O" is kept as is. For every other field the first two characters are
dropped (presumably a chunk prefix such as `B-`/`I-` - confirm against the tagger's
output) and the remainder is mapped: "PERSON" to "PER", "ORGANIZATION" to "ORG",
"LOCATION" or "GPE" to "LOC", a remaining "O" stays "O", anything else becomes "INVALID".
"""
function eval_nltk_tagger(ner_m, x_seq)
    tagged = ner_m(x_seq)

    return map(tagged) do item
        raw = item[2]
        entity = raw == "O" ? "O" : raw[3:end]

        if entity == "O"
            "O"
        elseif entity == "PERSON"
            "PER"
        elseif entity == "ORGANIZATION"
            "ORG"
        elseif entity == "LOCATION" || entity == "GPE"
            "LOC"
        else
            "INVALID"
        end
    end
end


eval_nltk_tagger (generic function with 1 method)

In [9]:
# Score the NLTK-based `nltk_ner` function on the same `X`/`Y` data, obtaining
# predictions through `eval_nltk_tagger`.
try_outs(nltk_ner, X, Y, eval_nltk_tagger)

For tag `ORG`
The precision is 0.656305114638448
The recall is 0.44162587153241356
f1 is 0.5279772989270196

For tag `O`
The precision is 0.984959026738824
The recall is 0.9771382151560868
f1 is 0.9810330342852588

For tag `PER`
The precision is 0.5971675845790716
The recall is 0.8194774346793349
f1 is 0.6908793009284544

For tag `LOC`
The precision is 0.5963878326996198
The recall is 0.7179311133997025
f1 is 0.6515395399553455

Overall Micro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7235561475388863
Overall Macro f1 for NER (excluding MISC) on CoNLL 2003 is 0.7128572935240195
