In [None]:
### these functions are not used for processing one scan;
# the function subtractCapThickness can be used to subtract the cap thickness from the reference points but needs to be modified first
# the function alignPoints is used in the function processAll
# the function processAll is used for processing a set of scans for one subject; it aligns the scans to one of them and
# is used to determine the repeatability of the method


#this function is not used but can be implemented if needed
def subtractCapThickness(outReference, numEl, readFile, order, capThickness, subtract=("Cz", "Iz", "Rpa", "Lpa")):
    '''
    Move the reference points whose stickers are placed on the cap (usually Cz, Iz, Rpa and Lpa)
    by the cap thickness along the normal of their fitted plane.

    The labels are handled in the fixed order [Nz, Iz, Rpa, Lpa, Cz]. For a label listed in subtract,
    the plane normal and the projected cluster center are taken from the last row of the csv file
    readFile.format(k), where k is the 1-based position of the label in that fixed order.
    For every other label the point is copied from outReference[order[position]].

    input:
        outReference: array with the points on the head surface (numEl, 3)
        numEl: number of rows of the returned array; must be at least 5 (int)
        readFile: name of the file to read the plane normals from (string with {} to be replaced by the number of the cluster)
        order: row index in outReference of each label (list of 5 ints) #example: [4, 0, 2, 3, 1]->[Nz, Iz, Rpa, Lpa, Cz] Nz is in the 4th line in outReference...
        capThickness: thickness of the cap [mm] (number)
        subtract: labels of the scanned reference points to subtract (sequence of strings)
    output:
        outReferenceNew: array with the points including subtracted cap thickness (numEl, 3);
                         rows 0-4 follow the order Nz, Iz, Rpa, Lpa, Cz, remaining rows stay zero
    raises:
        ValueError: if a plane file contains no rows, or a row does not hold exactly two fields
    '''
    allLabelsReference = ["Nz", "Iz", "Rpa", "Lpa", "Cz"]
    outReferenceNew = np.zeros((numEl, 3))
    for num, label in enumerate(allLabelsReference):
        if label not in subtract:
            #sticker is not on the cap: keep the scanned point
            outReferenceNew[num, :] = outReference[order[num], :]
            continue

        #get normal of plane and center of cluster; the last row of the file wins
        planeFile = readFile.format(num + 1)
        planeRow = None
        with open(planeFile, 'r') as csvfile:
            for planeRow in csv.reader(csvfile):
                normalText, centerText = planeRow
        if planeRow is None:
            #without this check the values parsed for a previous label would be reused silently
            raise ValueError("no plane data found in {}".format(planeFile))

        #both fields are stored as "[a b c ...]": strip the brackets and split on whitespace
        normal = np.array(normalText[1:-1].split(), dtype=float)
        centerProj = np.array(centerText[1:-1].split(), dtype=float)

        #depending on the position of the plane (sign of the 4th plane coefficient), reorient the vector
        direction = normal[:3]
        if normal[3] > 0:
            direction = -direction

        #subtract cap thickness from center of cluster
        outReferenceNew[num, :] = centerProj - direction * capThickness
    return outReferenceNew

#used in the function processAll
def alignPoints(numScannedOptodes,  orderedReference, outOptodes, bestScanRefPoints, bestScanOptodes):
    '''
    Rigidly align one scan to the selected best scan and return its optode points in the order of the best scan.

    The rotation and translation are obtained from rigidTransform3D using the reference points only and are
    then applied to both the reference and the optode points. Every transformed optode point is matched to
    its closest point in bestScanOptodes; the match has to be one-to-one.

    input:
        numScannedOptodes: number of scanned optodes points
        orderedReference: ordered scanned reference points (5, 3)->order: Nz, Iz, Rpa, Lpa, Cz
        outOptodes: scanned optodes points (numScannedOptodes, 3)
        bestScanRefPoints: reference points for alignment from the selected best scan (5, 3)
        bestScanOptodes: optode points from the selected best scan (numScannedOptodes, 3)
    output:
        orderedTrOptodes: transformed scanned optodes points: to fit scanned reference points from selected best scan (numScannedOptodes, 3) -> row j is the point closest to bestScanOptodes[j]
        trOrderedReference: transformed ordered scanned reference points: to fit reference points from selected best scan (5, 3)->order: Nz, Iz, Rpa, Lpa, Cz
        errLabels: distance of each transformed optode point to its closest point in bestScanOptodes (numScannedOptodes,)->same order as orderedTrOptodes
    raises:
        ValueError: if two transformed optode points are closest to the same best scan point
    '''
    #align scanned reference points to reference points from selected best scan
    myR, myT = rigidTransform3D(orderedReference.T, bestScanRefPoints.T)
    #transform scanned reference and optode points
    trOrderedReference = (myR @ orderedReference.T + myT).T
    trOptodes = np.asarray((myR @ outOptodes.T + myT).T)

    orderedTrOptodes = np.zeros_like(trOptodes)
    errLabels = np.zeros(numScannedOptodes)
    if numScannedOptodes == 0:
        #nothing to match
        return orderedTrOptodes, trOrderedReference, errLabels

    #distance from every transformed optode point (rows) to every best scan optode point (columns)
    candidates = trOptodes[:numScannedOptodes]
    targets = np.asarray(bestScanOptodes)[:numScannedOptodes]
    offsets = candidates[:, np.newaxis, :3] - targets[np.newaxis, :, :3]
    dist = np.sqrt(np.sum(offsets ** 2, axis=2))

    #find the closest best scan point of every transformed point and the distance to it
    closestIdx = np.argmin(dist, axis=1)
    closestDist = dist[np.arange(numScannedOptodes), closestIdx]

    #check if each point is closest to a different point; otherwise the re-ordering is not defined
    if np.unique(closestIdx).size != numScannedOptodes:
        raise ValueError('Error: some points are closest to the same point')

    #same order of points as in bestScan: the point matched to best scan point j goes to row j
    orderedTrOptodes[closestIdx, :] = candidates
    errLabels[closestIdx] = closestDist

    #return the re-ordered points so that rows agree with errLabels and with the best scan
    return orderedTrOptodes, trOrderedReference, errLabels

def processAll(subject, masks, scans, bestScan, numScannedOptodes, numScannedReference, radiusOptodes, radiusReference, myPath, subtractOptodeSize, subtractReferenceSize, distance2Plane, twoPoints="Nz", 
               printErr=True, plot=False, idxArr=(), incorrectRecognition=(), splitFirst=()):
    """
    used for processing a series of scans for one subject - aligning the scans to one of them
    running script needs to be in a folder containing folders with subjects' names; each subject folder contains folders with scans numbered 0 to scans-1;
    each scan folder contains a .obj and a .jpg file with the same name;

    the best scan is always processed first and serves as alignment target for all scans (including itself);
    the remaining scans follow in ascending order; the results are stored at the index of the scan number

    input:
        subject: subject name (string)
        masks: masks for stickers, used as reference points and on top of optodes (list of 12 floats between 0 and 1);
               entries 0-5 are hue/saturation/value center and width for the optode stickers, entries 6-11 the same for the reference stickers
        scans: number of scans (int)
        bestScan: best scan used for alignment of others, reference stickers are well seen (int)
        numScannedOptodes: number of scanned optodes points (int)
        numScannedReference: number of scanned reference points (int, 5 or 6 unless every scan is listed in incorrectRecognition)
        radiusOptodes: size of the circle to fit to the optode clusters [mm] (float)
        radiusReference: size of the circle fit to the reference clusters [mm], one value per scan (list of floats, indexed by scan number)
        myPath: path to the folder containing the files (WindowsPath)
        subtractOptodeSize: size to subtract from the optode cluster to get wanted point on the scalp [mm] (float)
        subtractReferenceSize: size to subtract from the reference cluster to get wanted point [mm] (float)
        distance2Plane: distance from fitted plane to keep points in clusters [mm] (float)
        twoPoints: label where there are two stickers together (string "Nz" or "Iz")
        printErr: print mean and std of errLabels for every scan (bool)
        plot: plot the last alignment (bool)
        idxArr: if the order is set wrong by the algorithm, the user can input the order; order of the scanned reference points
                (one list of 5 ints per scan in incorrectRecognition, given in the order in which these scans are processed)
        incorrectRecognition: if the order is set wrong by the algorithm, the user can input the order; which scan is incorrect (list of ints)
        splitFirst: if the 2 reference stickers are put too close, split first cluster; which scans need it (list of ints)
    output:
        trOrderedOptodes: transformed scanned optode points as returned by alignPoints for every scan (numScannedOptodes, 3, scans)
        trOrderedReference: transformed ordered scanned reference points to fit reference points from selected best scan (5, 3, scans)
        errLabels: error to the closest optode point of the selected best scan as returned by alignPoints [mm] (numScannedOptodes, scans)
    raises:
        ValueError: if a scan needs automatic ordering of the reference points and numScannedReference is not 5 or 6

    example: 
        trOrderedOptodes, trOrderedReference, errLabels=processAll("Filip", [0.14, 0.035, 0.65, 0.35, 0.8, 0.3, 0.35, 0.1, 0.45, 0.35, 0.45, 0.15], 8, 2, 61, 6, 
        6.5, [5, 5, 5, 5, 5, 5, 5, 5], pathlib.Path().resolve(), 22.6, 0, 2, "Iz")
    """

    #color mask: (hue center, hue width, sat center, sat width, value center, value width)
    optodeBounds = masks[0:6]
    referenceBounds = masks[6:12]

    trOrderedOptodes = np.zeros((numScannedOptodes, 3, scans))
    trOrderedReference = np.zeros((5, 3, scans))
    errLabels = np.zeros((numScannedOptodes, scans))
    step = 0
    for iteration in range(scans):
        #first iteration goes through the best scan indicated by the input and is used as reference for other scans;
        #the scans before the best scan are shifted by one so that every scan is visited exactly once
        if iteration == 0:
            scan = bestScan
        elif iteration <= bestScan:
            scan = iteration - 1
        else:
            scan = iteration

        userOrdered = scan in incorrectRecognition
        if not userOrdered and numScannedReference not in (5, 6):
            #fail before any file is written instead of continuing with an undefined or stale order
            raise ValueError("automatic ordering of the reference points needs numScannedReference to be 5 or 6, got {}".format(numScannedReference))

        #reset number of scanned reference points
        numRefenceAlt = numScannedReference
        #file to write and read clusters and planes; the remaining {} gets filled in in the functions
        scanDir = subject + "/scan{}".format(scan)
        readFileClustersReference = scanDir + "/pointsReference{}.ply"
        readFileClustersOptode = scanDir + "/pointsOptode{}.ply"
        readFilePlaneReference = scanDir + "/normals(+d)Reference{}.csv"
        readFilePlaneOptode = scanDir + "/normals(+d)Optode{}.csv"
        #get vertex locations and colors
        vnew, vcolors_rgb, vcolors_hsv = preProcessing(scanDir + "/scan{}post".format(scan))
        #define mask
        MaskOptodes = _hsvMask(vcolors_hsv, optodeBounds)
        MaskReference = _hsvMask(vcolors_hsv, referenceBounds)
        #get points in mask
        newPointsOptodes = pointsiInMask(vnew, MaskOptodes)
        newPointsReference = pointsiInMask(vnew, MaskReference)
        #if the recognition of the ordered points is incorrect, user should input which scan is wrong and the order of the scanned reference points;
        #this happened a couple times where another cluster had more detected points in the mask than scanned reference points;
        #enough clusters are requested to include the highest index given by the user
        if userOrdered:
            numRefenceAlt = np.max(idxArr[step]) + 1
        #get as many clusters as there are scanned optode or reference points
        finalClustersOptodes = finalClusters(newPointsOptodes, eps=radiusOptodes, minSamples=50, numEl=numScannedOptodes)
        finalClustersReference = finalClusters(newPointsReference, eps=radiusReference[scan], minSamples=50, numEl=numRefenceAlt)
        #in 2 sets of scans, the 6th reference sticker was too close to another one and the clusters were merged
        #6th sticker is used to determine one of the points, followed by the rest
        if scan in splitFirst:
            splitClusters = finalClusters(finalClustersReference[0], eps=3, minSamples=50, numEl=2)
            #shift clusters 1-4 one position back (highest first) to make room for the two halves of cluster 0
            finalClustersReference[5] = finalClustersReference[4]
            finalClustersReference[4] = finalClustersReference[3]
            finalClustersReference[3] = finalClustersReference[2]
            finalClustersReference[2] = finalClustersReference[1]
            finalClustersReference[1] = splitClusters[1]
            finalClustersReference[0] = splitClusters[0]
        #write points belonging to each cluster to a file
        writePoints(finalClustersOptodes, readFileClustersOptode)
        writePoints(finalClustersReference, readFileClustersReference)
        #fit a plane to each cluster
        #NOTE(review): the reference planes are fitted with diameter=2*radiusOptodes, not with the reference radius - confirm this is intended
        getPlane(finalClustersOptodes, myPath, numScannedOptodes, readFileClustersOptode, readFilePlaneOptode, diameter=2*radiusOptodes)
        getPlane(finalClustersReference, myPath, numRefenceAlt, readFileClustersReference, readFilePlaneReference, diameter=2*radiusOptodes)
        #cut points too far from the plane and rewrite the points file
        writePointsNearPlane(numScannedOptodes, readFileClustersOptode, readFilePlaneOptode, myPath, distance2Plane)
        writePointsNearPlane(numRefenceAlt, readFileClustersReference, readFilePlaneReference, myPath, distance2Plane)
        #re-fit the plane
        getPlane(finalClustersOptodes, myPath, numScannedOptodes, readFileClustersOptode, readFilePlaneOptode, diameter=2*radiusOptodes)
        getPlane(finalClustersReference, myPath, numRefenceAlt, readFileClustersReference, readFilePlaneReference, diameter=2*radiusOptodes)
        #subtract optode size from each point to get a point on the surface of the head
        outOptodes = subtractOptode(readFilePlaneOptode, numScannedOptodes, subtractOptodeSize)
        outReference = subtractOptode(readFilePlaneReference, numRefenceAlt, subtractReferenceSize)
        if userOrdered:
            #if the order is pre-defined, use it
            orderedReference = np.zeros((5, 3))
            for i in range(5):
                orderedReference[i, :] = outReference[idxArr[step][i], :]
            step += 1
        elif numScannedReference == 5:
            #order the scanned reference points, use a different function for 6 reference points
            #5 reference points function works for montages, where Iz has the lowest mean distance to closest 5 scanned optode points
            orderedReference, _ = orderReferencePoints(numScannedOptodes, numScannedReference, outReference, outOptodes)
        else:
            orderedReference, _ = orderReferencePoints6(numScannedReference, outReference, twoPoints)
        #use best scan as reference for alignment for the rest of the scans
        if scan == bestScan:
            bestScanRefPoints = orderedReference
            bestScanOptodes = outOptodes
        #align the points to the best scan reference points, return scanned optode points, ordered scanned reference points and error to closest best scan optode point
        trOrderedOptodes[:, :, scan], trOrderedReference[:, :, scan], errLabels[:, scan] = alignPoints(numScannedOptodes, orderedReference, outOptodes, bestScanRefPoints, bestScanOptodes)
        #print mean and std of error
        if printErr:
            print(np.mean(errLabels[:, scan]), np.std(errLabels[:, scan]))
    #plot the last alignment; the aligned points are taken from the slice of the last processed scan only
    if plot and scans > 0:
        _plotAlignment(bestScanRefPoints, orderedReference, trOrderedReference[:, :, scan], "reference")
        _plotAlignment(bestScanOptodes, outOptodes, trOrderedOptodes[:, :, scan], "optode")

    return trOrderedOptodes, trOrderedReference, errLabels


def _hsvMask(vcolors_hsv, bounds):
    '''
    return a boolean mask selecting the rows of vcolors_hsv whose hue, saturation and value (columns 0, 1, 2)
    each lie closer to their center than the corresponding width;
    bounds: (hue center, hue width, sat center, sat width, value center, value width)
    '''
    hueCenter, hueWidth, satCenter, satWidth, valueCenter, valueWidth = bounds
    return ((np.abs(vcolors_hsv[:, 0] - hueCenter) < hueWidth)
            & (np.abs(vcolors_hsv[:, 1] - satCenter) < satWidth)
            & (np.abs(vcolors_hsv[:, 2] - valueCenter) < valueWidth))


def _plotAlignment(bestPoints, unalignedPoints, alignedPoints, kind):
    '''
    draw one figure with two 3d panels: best scan points against the unaligned points of the last scan (left)
    and against the aligned points of the last scan (right); kind ("reference" or "optode") is used in the legend
    '''
    fig = p.figure(figsize=(10, 10))
    panels = ((unalignedPoints, 'r', 'unaligned last scan {} points'),
              (alignedPoints, 'g', 'aligned last scan {} points'))
    for column, (points, color, label) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, column, projection="3d")
        ax.scatter(bestPoints[:, 0], bestPoints[:, 1], bestPoints[:, 2], s=10, color='b', label='best scan {} points'.format(kind))
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=10, color=color, label=label.format(kind))
        ax.legend()