In [None]:
def evaluateOrnamentsOnPageInner(ornaments, extractedOrnaments, bookId, pageId):
    """Greedily match proposals to ground-truth ornaments on a single page.

    The IoU between every proposal and every true ornament is computed with
    ``getIoU``. Proposals are then matched greedily: at each step, the
    proposal whose best still-unused true ornament has the highest IoU is
    assigned that ornament, and the ornament is marked as used. A proposal
    left with no unused ornament is reported with ``trueBox=None``. A true
    ornament never matched is reported with ``proposal=None``.

    NOTE(review): a proposal is matched even when its best IoU is 0, so a
    ``trueBox`` can be reported with ``iou == 0``. This is kept from the
    original behaviour.

    Args:
        ornaments: ground-truth boxes, in whatever form ``getIoU`` accepts.
        extractedOrnaments: proposal boxes, in the same form.
        bookId: copied verbatim into the result.
        pageId: copied verbatim into the result.

    Returns:
        A dict with:
        - ``bookId`` and ``pageId``.
        - ``bestProposalPerTrueOrnament``: ornament index -> highest IoU
          reached by any proposal (0 if none is positive).
        - ``bestTrueOrnamentPerProposal``: proposal index -> highest IoU
          with any ornament (0 if none is positive).
        - ``results``: a list of ``{'proposal', 'trueBox', 'iou'}`` dicts,
          one per proposal in matching order, followed by one per unmatched
          true ornament.
    """
    bestProposalPerTrueOrnament = {j: 0 for j in range(len(ornaments))}
    bestTrueOrnamentPerProposal = {i: 0 for i in range(len(extractedOrnaments))}

    # One (proposal, [(iou, ornamentIndex), ...]) pair per proposal. Each
    # list is sorted best-first; equal IoUs are ordered by higher index.
    candidates = []
    for i, proposal in enumerate(extractedOrnaments):
        ious = []
        for j, ornament in enumerate(ornaments):
            iou = getIoU(proposal, ornament)
            ious.append((iou, j))
            if iou > bestProposalPerTrueOrnament[j]:
                bestProposalPerTrueOrnament[j] = iou
            if iou > bestTrueOrnamentPerProposal[i]:
                bestTrueOrnamentPerProposal[i] = iou
        candidates.append((proposal, sorted(ious, reverse=True)))

    def headIoU(candidate):
        # A proposal with no available ornament left must never beat one
        # that still has one, so it ranks lowest.
        remaining = candidate[1]
        return remaining[0][0] if remaining else float('-inf')

    results = []
    usedOrnament = set()
    while candidates:
        # Drop leading entries whose ornament is already claimed, so that
        # each list's head is that proposal's best *available* ornament.
        for _, ious in candidates:
            while ious and ious[0][1] in usedOrnament:
                ious.pop(0)

        # max() returns the first maximal candidate (first one wins ties).
        # Popping by index avoids equality comparisons on the boxes.
        bestIndex = max(range(len(candidates)), key=lambda k: headIoU(candidates[k]))
        proposal, ious = candidates.pop(bestIndex)

        if ious:
            iou, j = ious[0]
            usedOrnament.add(j)
            results.append({
                    'proposal': proposal,
                    'trueBox': ornaments[j],
                    'iou': iou})
        else:
            results.append({
                    'proposal': proposal,
                    'trueBox': None,
                    'iou': 0})

    # True ornaments that no proposal claimed count as misses.
    for j, ornament in enumerate(ornaments):
        if j not in usedOrnament:
            results.append({
                    'proposal': None,
                    'trueBox': ornament,
                    'iou': 0})

    return {
        'bookId': bookId,
        'pageId': pageId,
        'bestProposalPerTrueOrnament': bestProposalPerTrueOrnament,
        'bestTrueOrnamentPerProposal': bestTrueOrnamentPerProposal,
        'results': results
    }
    
def evaluateOrnamentsOnPage(annotatedPage, predictions, threshold):
    """Evaluate one page, keeping only proposals whose prediction exceeds *threshold*.

    ``annotatedPage['proposals']`` and ``predictions`` are paired element-wise.
    A proposal is kept when ``prediction > threshold``. The kept proposals are
    then matched against ``annotatedPage['ornaments']`` by
    ``evaluateOrnamentsOnPageInner``.

    NOTE(review): the error message below mentions an expected shape of
    ``(n, 2)``. However, each prediction is compared directly with
    ``threshold``, which only works for scalar (or single-element)
    predictions. Confirm what callers actually pass.

    Args:
        annotatedPage: dict with ``ornaments``, ``proposals``, ``bookId`` and
            ``pageId`` entries.
        predictions: one score per proposal (array or list).
        threshold: minimum score, exclusive, for a proposal to be kept.

    Returns:
        The evaluation dict produced by ``evaluateOrnamentsOnPageInner``.
    """
    ornaments = list(map(parseOrnament, annotatedPage['ornaments']))
    rawExtractedOrnaments = list(map(parseOrnament, annotatedPage['proposals']))

    if len(rawExtractedOrnaments) != len(predictions):
        # np.shape works for plain lists too, where `.shape` would raise.
        print('Error, prediction has shape {0} when we expected shape ({1}, 2)'.format(
                np.shape(predictions), len(rawExtractedOrnaments)))

    # zip() stops at the shorter sequence: on a length mismatch the surplus
    # proposals or predictions are ignored.
    extractedOrnaments = [proposal
                          for proposal, prediction in zip(rawExtractedOrnaments, predictions)
                          if prediction > threshold]

    return evaluateOrnamentsOnPageInner(ornaments,
                                        extractedOrnaments,
                                        annotatedPage['bookId'],
                                        annotatedPage['pageId'])
    

def evaluateOrnamentsOnPageWithoutFilter(annotatedPage):
    """Evaluate one page using every proposal, with no score-based filtering.

    Args:
        annotatedPage: dict with ``ornaments``, ``proposals``, ``bookId`` and
            ``pageId`` entries.

    Returns:
        The evaluation dict produced by ``evaluateOrnamentsOnPageInner``.
    """
    truthBoxes = [parseOrnament(raw) for raw in annotatedPage['ornaments']]
    proposalBoxes = [parseOrnament(raw) for raw in annotatedPage['proposals']]
    return evaluateOrnamentsOnPageInner(
        truthBoxes,
        proposalBoxes,
        annotatedPage['bookId'],
        annotatedPage['pageId'],
    )

def parseScores(scores):
    """Flatten the per-page best-IoU maps of several page evaluations.

    Args:
        scores: iterable of page-evaluation dicts. Each must have
            ``bestProposalPerTrueOrnament`` and ``bestTrueOrnamentPerProposal``
            dict entries.

    Returns:
        ``(dataOrnament, dataProposal)``: two 1-D float arrays holding, in
        page order, the dict values of ``bestProposalPerTrueOrnament`` and of
        ``bestTrueOrnamentPerProposal`` respectively.
    """
    # Seeding with an empty float array keeps the result float64, as before,
    # even when every page is empty or holds only integer zeros.
    ornamentChunks = [np.array([])]
    proposalChunks = [np.array([])]
    for pageScore in scores:
        ornamentChunks.append(np.array(list(pageScore['bestProposalPerTrueOrnament'].values())))
        proposalChunks.append(np.array(list(pageScore['bestTrueOrnamentPerProposal'].values())))

    # A single concatenate avoids re-copying the accumulated data on every page.
    dataOrnament = np.concatenate(ornamentChunks, axis=0)
    dataProposal = np.concatenate(proposalChunks, axis=0)
    return dataOrnament, dataProposal

def plotData(data, logScale=True, includeSteam=True):
    """Histogram the strictly positive values of *data*.

    The histogram uses 22 bins, each 0.05 wide, spanning (-0.05, 1.05).

    Args:
        data: numpy array of scores (e.g. IoUs).
        logScale: use a logarithmic y axis.
        includeSteam: also draw a stem at x=0 whose height is the number of
            values exactly equal to 0 (those are excluded from the histogram).
    """
    positiveValues = data[data > 0]
    plt.hist(positiveValues, bins=22, range=(-0.05, 1.05), log=logScale)

    if includeSteam:
        zeroCount = np.sum(data == 0)
        plt.stem([0], [zeroCount])

    if logScale:
        plt.semilogy()
        
def printDataInfo(data, threshold):
    """Print a summary of how the values of *data* sit relative to *threshold*.

    Four lines are printed:
    - the total number of values;
    - the count strictly above ``threshold``;
    - the count at or below ``threshold``, excluding exact zeros;
    - the count exactly equal to 0.

    Args:
        data: numpy array of values.
        threshold: comparison threshold.
    """
    numExactlyZero = np.count_nonzero(data == 0)
    numHigher = np.count_nonzero(data > threshold)
    # Exclude zeros from the "lower" bucket only where they actually fall in
    # it. Subtracting every zero gave a wrong (even negative) count whenever
    # threshold < 0.
    numLowerNonZero = np.count_nonzero((data <= threshold) & (data != 0))

    print("data size: ", data.size)
    print("higher {0}: {1}".format(threshold, numHigher))
    print("lower {0}: {1}".format(threshold, numLowerNonZero))
    print("exactly 0: ", numExactlyZero)
    
def confidenceInterval(data, confidence, numTails=2):
    n = data.size
    if n == 0:
        return (1, 1)
    
    mean = np.mean(data)
    s = np.std(data)
    if numTails == 2:
        confidence = 2*confidence-1
    
    t = stats.t.ppf(confidence, n-1) # critical value
    if t < 0 or t == inf or isnan(t):
        return (mean, mean)
    
    tsSqrtN = t * s / sqrt(n)
    return (mean - tsSqrtN, mean + tsSqrtN)

def getPrecisionRecall(dataOrnament, dataProposal, iouThreshold, confidence, printValues):
    """Estimate detection precision and ornament recall as confidence intervals.

    A proposal counts as correct when its best IoU is strictly above
    ``iouThreshold``; the fraction of such proposals is the precision. A true
    ornament counts as found under the same rule; that fraction is the recall.
    Each fraction is the mean of a boolean array, and its interval comes from
    ``confidenceInterval``.

    Args:
        dataOrnament: best IoU per true ornament (numpy array).
        dataProposal: best IoU per proposal (numpy array).
        iouThreshold: IoU above which a match counts.
        confidence: confidence level forwarded to ``confidenceInterval``.
        printValues: when true, print the hit counts and both intervals.

    Returns:
        ``(precisionCi, recallCi)``, each a ``(lower, upper)`` pair.
    """
    proposalHits = dataProposal > iouThreshold
    ornamentHits = dataOrnament > iouThreshold

    if printValues:
        print('dataProposalThresholded: ', np.sum(proposalHits))
        print('dataOrnamentThresholded: ', np.sum(ornamentHits))

    precisionCi = confidenceInterval(proposalHits, confidence)
    recallCi = confidenceInterval(ornamentHits, confidence)

    if printValues:
        # The interval midpoint, reported as a percentage.
        precisionPercent = (precisionCi[0] + precisionCi[1]) * 50
        recallPercent = (recallCi[0] + recallCi[1]) * 50
        print("Precision detection: {0}%\n\t{1}".format(precisionPercent, precisionCi))
        print("Recall ornament: {0}%\n\t{1}".format(recallPercent, recallCi))

    return precisionCi, recallCi

def writePrecisionRecall(precision, recall, method):
    """Append one precision/recall point for *method* to ``precisionRecall.json``.

    The file path is ``outputFolder`` (a global defined outside this block;
    presumably it ends with a path separator) followed by
    ``precisionRecall.json``. The JSON layout is::

        {<method['hash']>: [{'precision': ..., 'recall': ...}, ...],
         'methods': {<method['hash']>: method}}

    Existing content is loaded and preserved.

    Args:
        precision: precision value or interval (must be JSON-serializable).
        recall: recall value or interval (must be JSON-serializable).
        method: dict describing the method; ``method['hash']`` is its key.
    """
    filepath = "{0}precisionRecall.json".format(outputFolder)
    if os.path.isfile(filepath):
        with open(filepath, encoding="utf-8") as jsonFile:
            precisionRecall = json.load(jsonFile)
    else:
        precisionRecall = {}

    methodHash = method['hash']
    precisionRecall.setdefault(methodHash, []).append({
            'precision': precision,
            'recall': recall
        })
    precisionRecall.setdefault('methods', {})[methodHash] = method

    # Serialize before opening for writing. If json.dumps fails, the
    # existing file is then left intact instead of being truncated.
    payload = json.dumps(precisionRecall, indent=4, sort_keys=True)
    with open(filepath, "w", encoding="utf-8") as jsonFile:
        jsonFile.write(payload)

    
def writeCurvesInJson(curves, name):
    """Store *curves* under the key *name* in ``curves.json``.

    The file path is ``outputFolder`` (a global defined outside this block;
    presumably it ends with a path separator) followed by ``curves.json``.
    Other entries already in the file are kept. An existing entry with the
    same *name* is overwritten.

    Args:
        curves: JSON-serializable curve data.
        name: key under which the curves are stored.
    """
    filepath = "{0}curves.json".format(outputFolder)
    if os.path.isfile(filepath):
        with open(filepath, encoding="utf-8") as jsonFile:
            curvesSet = json.load(jsonFile)
    else:
        curvesSet = {}

    curvesSet[name] = curves

    # Serialize before opening for writing, so a non-serializable value
    # cannot truncate the existing file.
    payload = json.dumps(curvesSet, indent=4, sort_keys=True)
    with open(filepath, "w", encoding="utf-8") as jsonFile:
        jsonFile.write(payload)
    

def computePrecisionRecall(dataOrnament, dataProposal, iouThreshold, showGraths, saveResults, method,
                           confidence=0.995):
    """Compute precision/recall intervals, optionally plotting and saving them.

    Args:
        dataOrnament: best IoU per true ornament (numpy array).
        dataProposal: best IoU per proposal (numpy array).
        iouThreshold: IoU above which a match counts.
        showGraths: when true, show histograms and summaries of both arrays
            and print the computed intervals.
        saveResults: when true, append the result to the precision/recall
            JSON file via ``writePrecisionRecall``.
        method: method description dict passed to ``writePrecisionRecall``.
        confidence: confidence level for the intervals. It defaults to the
            previously hard-coded 0.995.

    Returns:
        ``(precision, recall)``, each a ``(lower, upper)`` interval. These
        were previously computed and discarded.
    """
    if showGraths:
        print("Ornaments:")
        plotData(dataOrnament, logScale=False, includeSteam=True)
        printDataInfo(dataOrnament, iouThreshold)
        plt.show()
        print("\n")

        print("Proposals:")
        plotData(dataProposal, logScale=False, includeSteam=True)
        printDataInfo(dataProposal, iouThreshold)
        plt.show()

    precision, recall = getPrecisionRecall(dataOrnament, dataProposal, iouThreshold, confidence, showGraths)

    if saveResults:
        writePrecisionRecall(precision, recall, method)

    return precision, recall

def getAverageAndEdge(tupleValue):
    """Expand an interval into ``(low, midpoint, high)``.

    Only the first two items of *tupleValue* are read, as ``(low, high)``.
    """
    low = tupleValue[0]
    high = tupleValue[1]
    midpoint = (low + high) / 2
    return (low, midpoint, high)

def plotCurve(curve, arg='b-', label=''):
    """Plot a sequence of ``(x, y)`` points as one matplotlib line.

    Args:
        curve: sequence of ``(x, y)`` pairs.
        arg: matplotlib format string (colour + line style).
        label: legend label.
    """
    abscissae, ordinates = zip(*curve)
    plt.plot(abscissae, ordinates, arg, label=label)
    
def plotWithInterval(curves, color='b', cut=0, labels=('', '', '')):
    """Plot a central curve (solid) between two bound curves (dashed).

    Args:
        curves: three curves, each a sequence of ``(x, y)`` points.
            ``curves[0]`` and ``curves[2]`` are drawn dashed and
            ``curves[1]`` solid; presumably these are (lower, average, upper).
        color: matplotlib colour prefix combined with the line style.
        cut: number of leading points dropped from the two dashed curves
            only. NOTE(review): the solid curve is never cut; confirm this is
            intended.
        labels: legend labels for ``curves[0]``, ``curves[1]`` and
            ``curves[2]``. A tuple default replaces the former mutable list
            default.
    """
    plotCurve(curves[0][cut:], color + '--', label=labels[0])
    plotCurve(curves[1], color + '-', label=labels[1])
    plotCurve(curves[2][cut:], color + '--', label=labels[2])