In [1]:
import os
import math

In [2]:
def read_pred(fpred):
    id2pred = {}
    with open(fpred) as f:
        for line in f:
            if line[:2] == "H-":
                idx, score, pred = line.split("\t")
                idx = int(idx.strip().split("-")[-1])
                pred = pred.strip()
                if idx not in id2pred:
                    id2pred[idx] = pred
    return id2pred

def read_gold(fgold):
    id2gold = {}
    idx = 0
    with open(fgold) as f:
        for line in f:
            id2gold[idx] = line.strip()
            idx += 1
    return id2gold

def eval_file(fpred, fgold):
    id2pred = read_pred(fpred)
    id2gold = read_gold(fgold)
    guess = 0
    correct = 0
    assert len(id2pred) == len(id2gold)
    for idx, gold in id2gold.items():
        guess += 1
        if gold == id2pred[idx]:
            correct += 1
    acc = round(100*correct/guess, 2)
    return acc, correct, guess

In [3]:
def _get_squared_dif(acclist, avg):
    sum_dif = 0
    for acc in acclist:
        dif = acc - avg
        sum_dif += dif * dif
    return sum_dif

def get_avg_std(acclist):
    avg = sum(acclist)/len(acclist)
    std = math.sqrt(_get_squared_dif(acclist, avg)/3)
    return round(avg, 2), round(std, 2)

In [4]:
# langlist = ["czech", "finnish", "german", "russian", "spanish", "turkish"]
langlist = ["aka", "gaa", "lin", "nya", "sot", "swa"]

splittype = "test"
# splittype = "dev"

# seedlist = [str(i)] for i in range(1, 6)]
foldlist = [str(i) for i in range(1, 6)] 

In [5]:
datadir = "PATH_TO_DIRECTORY_WHERE_THE_GROUND_TRUTH_IS_STORED_IN_FAIRSEQ_FORMAT"
preddir = "PATH_TO_CHECKPOINTS_WHERE_INFLECTION_PREDICTIONS_ARE_STORED"

In [7]:
for lang in langlist:
    acclist = []
#     for seed in seedlist:
    for fold in foldlist:
        datadir_now = os.path.join(datadir, "data_niger_congo" + fold)
        langname = lang + fold
        fgold = os.path.join(datadir_now, langname+"/"+splittype+"."+langname+".output")
        fpred = os.path.join(preddir, langname+"-predictions"+fold+"/"+splittype+"-checkpoint_best.pt.txt")
        if os.path.exists(fpred):
            acc, correct, guess = eval_file(fpred, fgold)
            acclist.append(acc)
        else:
            print("NOT EXISTS:", fpred)
            acc, correct, guess = "-", "-", "-"
    avg, std = get_avg_std(acclist)
    print("{} +/- {}".format(avg, std))
#     print(acclist)

41.18 +/- 18.8
49.24 +/- 11.88
60.83 +/- 16.7
84.99 +/- 5.42
6.74 +/- 8.62
45.26 +/- 18.54
