In [13]:
import numpy as np
from scipy import stats
from pymoo.util.misc import stack
from pymoo.model.problem import Problem

from sklearn.metrics import r2_score

#pd.options.display.precision = 2


class testMyProblem(Problem):
    """pymoo problem that fits a straight track through a detector array.

    A decision vector is a line given by two points
    ``(x0, y0, z0, x1, y1, z1)``.  Three objectives are minimised:

    * ``hitcost``  -- penalty for hit channels lying far from the line;
    * ``misscost`` -- penalty for missed channels lying close to the line;
    * ``lincost``  -- penalty when the energies of the usable hit channels do
      not scale linearly with the track length through them.

    Channels are modelled as axis-aligned cubes of edge ``CUBE_LENGTH``
    centred on the coordinates read by ``load_coords``.
    """

    # Edge length of one channel cube, in the same units as the coordinates.
    CUBE_LENGTH = 50

    def __init__(self, hit_chs, miss_chs, non_sat_chs, energies, pwd="."):
        """Load detector geometry and set up the 6-variable, 3-objective problem.

        Parameters
        ----------
        hit_chs : sequence of int
            Channels that registered the event.
        miss_chs : sequence of int
            Channels that did not register the event.
        non_sat_chs : sequence of int
            Hit channels whose energies enter the linearity cost
            (presumably the non-saturated ones -- confirm with the caller).
        energies : sequence of float
            One energy per entry of ``hit_chs``, in the same order.
        pwd : str, optional
            Directory containing ``data/detector_positions.txt``.

        Raises
        ------
        KeyError
            If a channel is not present in the coordinate file.
        ValueError
            If fewer than two hit channels are given (no line can be seeded).
        """
        self.pwd = pwd
        self.coords = self.load_coords()

        self.hit_chs = hit_chs
        self.miss_chs = miss_chs
        self.non_sat_chs = non_sat_chs
        self.energies = energies

        self.hit_pts = self._channel_points(hit_chs)
        self.miss_pts = self._channel_points(miss_chs)

        # Search box: every coordinate may move +/- 26 around the two hit
        # channel centres closest to the best-fit line through all hits.
        boundary = self.boundary_coords()

        super().__init__(n_var=6,
                         n_obj=3,
                         n_constr=0,
                         xl=boundary - 26,
                         xu=boundary + 26)

    def _channel_points(self, channels):
        """Return the centres of *channels* as a ``(3, N)`` float array.

        The reshape keeps a ``(3, 0)`` shape for an empty channel list, so the
        distance helpers still broadcast against it.
        """
        return np.array([self.coords[ch] for ch in channels], dtype=float).reshape(-1, 3).T

    def _evaluate(self, x, out, *args, **kwargs):
        """pymoo hook: store one row of the three objectives per row of *x* in ``out["F"]``."""
        out["F"] = np.column_stack([self.hitcost(x), self.misscost(x), self.lincost(x)])

    def boundary_coords(self):
        """Return the two hit-channel centres nearest the best-fit line through all hits.

        The line passes through the centroid of the hit centres along their
        first principal direction (from an SVD).  The result is a flat array
        ``(x0, y0, z0, x1, y1, z1)`` in the same layout as a decision vector.

        Raises
        ------
        ValueError
            If fewer than two hit channels are available.
        """
        pts = self.hit_pts.T  # (N, 3)
        if len(pts) < 2:
            raise ValueError(f"at least two hit channels are required, got {len(pts)}")

        centroid = pts.mean(axis=0)
        _, _, vh = np.linalg.svd(pts - centroid)
        direction = vh[0] / np.linalg.norm(vh[0])
        linepts = self.line_to_pts(np.append(centroid, direction)).flatten()

        distances = self.ptsfromline(self.hit_pts, linepts)
        # A stable sort breaks distance ties by channel index.
        nearest = np.argsort(distances, kind="stable")[:2]
        return pts[nearest].flatten()

    def ptsfromline(self, pts, linepts):
        """Return the perpendicular distance from each column of *pts* to a line.

        Parameters
        ----------
        pts : ndarray, shape (3, N)
            Points, one per column.
        linepts : array-like of 6 floats
            Two distinct points on the line, ``(x0, y0, z0, x1, y1, z1)``.

        Returns
        -------
        ndarray, shape (N,)
        """
        linepts = np.asarray(linepts, dtype=float)
        a = linepts[:3, np.newaxis]
        b = linepts[3:, np.newaxis]
        # |(p - a) x (p - b)| is the parallelogram area spanned at p; dividing
        # by the base length |b - a| leaves the height, i.e. the distance.
        return np.linalg.norm(np.cross(pts - a, pts - b, axis=0), axis=0) / np.linalg.norm(b - a)

    def hitcost(self, x):
        """Per-line penalty for hit channels lying far from the line.

        Each hit channel adds ``1 / (1 + exp(-0.2 * (d - r)))``, where ``d`` is its
        centre's distance to the line and ``r`` is the cube half-diagonal.
        Beyond ``r`` the line cannot intersect the cube, and the term rises
        from 0.5 towards 1.

        Parameters
        ----------
        x : iterable of 6-element line vectors (the pymoo population).

        Returns
        -------
        list of float, one cost per line.
        """
        half_diagonal = self.CUBE_LENGTH / 2 * np.sqrt(3)
        costs = []
        for linepts in x:
            d = self.ptsfromline(self.hit_pts, linepts)
            costs.append(np.sum(1 / (1 + np.exp(-.2 * (d - half_diagonal)))))
        return costs

    def misscost(self, x):
        """Per-line penalty for missed channels lying close to the line.

        Each missed channel adds ``1 / (1 + exp(0.2 * (d - L/2)))``, where ``d`` is
        its centre's distance to the line and ``L`` is ``CUBE_LENGTH``.  The
        term is close to 1 well inside half a cube length, 0.5 at ``L/2``,
        and decays towards 0 farther out.

        Returns
        -------
        list of float, one cost per line in *x*.
        """
        half_length = self.CUBE_LENGTH / 2
        costs = []
        for linepts in x:
            d = self.ptsfromline(self.miss_pts, linepts)
            costs.append(np.sum(1 / (1 + np.exp(.2 * (d - half_length)))))
        return costs

    def lincost(self, x):
        """Per-line penalty for a non-linear energy-vs-track-length relation.

        For every line, the chord length through each crossed channel is
        computed.  Each hit channel that is in ``non_sat_chs`` and crossed by
        the line is paired with its energy.  A zero-intercept least-squares
        fit ``energy = m * length`` is then made over those pairs.

        The cost is ``modifier - R^2``, where
        ``modifier = len(non_sat_chs) - n_pairs + 1`` penalises usable
        channels the line misses.  With fewer than three pairs no fit is made
        and the cost is ``modifier``.

        Returns
        -------
        list of float, one cost per line in *x*.
        """
        non_sat = set(self.non_sat_chs)
        # (channel, energy) for every hit channel usable in the fit, in hit_chs order.
        candidates = [(ch, energy) for ch, energy in zip(self.hit_chs, self.energies)
                      if ch in non_sat]
        n_non_sat = len(self.non_sat_chs)

        costs = []
        for linepts in x:
            hit_channels, _, track_distances = self.channelcollisions(self.pts_to_line(linepts))
            length_of = dict(zip(hit_channels, track_distances))

            pairs = [(length_of[ch], energy) for ch, energy in candidates if ch in length_of]
            modifier = n_non_sat - len(pairs) + 1

            if len(pairs) < 3:
                costs.append(modifier)
                continue

            lengths, energy = np.array(pairs, dtype=float).T
            slope, *_ = np.linalg.lstsq(lengths[:, np.newaxis], energy, rcond=None)
            r2 = r2_score(energy, lengths * slope[0])
            costs.append(-r2 + modifier)

        return costs

    def lineplanecollision(self, planeNormal, planePoint, rayDirection, rayPoint, epsilon=1e-6):
        """Return the intersection of a line with a plane, or None if (nearly) parallel.

        The line is ``rayPoint + t * rayDirection``.  The plane passes through
        *planePoint* with normal *planeNormal*.  The parallel test compares the
        unnormalised dot product ``planeNormal . rayDirection`` against *epsilon*.
        """
        ndotu = planeNormal.dot(rayDirection)
        if abs(ndotu) < epsilon:
            return None

        t = -planeNormal.dot(rayPoint - planePoint) / ndotu
        return rayPoint + t * rayDirection

    def linecubecollision(self, cubeCenter, cubeLength, rayDirection, rayPoint, epsilon=1e-6):
        """Return the points where a line meets the surface of an axis-aligned cube.

        Each of the six face planes is intersected with the line.  An
        intersection is kept when it lies on the cube, within *epsilon*.

        If the line passes exactly through an edge or vertex, two or three
        faces yield the same point.  Such coincident points are reported once,
        so a line crossing the cube's interior normally yields two points, and
        one that merely touches an edge yields one.
        """
        center = np.asarray(cubeCenter, dtype=float)
        half = cubeLength / 2.0
        lower = center - half - epsilon
        upper = center + half + epsilon

        axes = np.array([
            [0, 0, half],  # up
            [0, half, 0],  # front
            [half, 0, 0],  # right
        ])
        # Same face order as before: up, front, right, down, back, left.
        face_normals = list(axes) + [-axis for axis in axes]

        collisions = []
        for normal in face_normals:
            point = self.lineplanecollision(normal, center + normal, rayDirection, rayPoint)
            if point is None:
                continue
            if np.any(point > upper) or np.any(point < lower):
                continue  # hits the face plane outside the cube
            if any(np.linalg.norm(point - seen) <= epsilon for seen in collisions):
                continue  # same point reached via an adjacent face (edge/vertex)
            collisions.append(point)

        return collisions

    def channelcollisions(self, line, epsilon=1e-6):
        """Classify detector channels by whether a line passes through them.

        Parameters
        ----------
        line : array-like of 6 floats
            ``(px, py, pz, vx, vy, vz)``: a point on the line and its direction.

        Returns
        -------
        (hit_channels, miss_channels, track_distances)
            ``hit_channels`` lists the channels whose cube the line crosses
            (two surface points); ``track_distances`` holds the matching chord
            lengths.  ``miss_channels`` lists the channels whose centre is
            within the cube half-diagonal of the line but which are not
            crossed.  Channels farther away appear in neither list.  Channels
            are visited in ascending channel number.
        """
        rayPoint = np.asarray(line[:3], dtype=float)
        rayDirection = np.asarray(line[3:], dtype=float)
        cube_length = self.CUBE_LENGTH

        reach = cube_length / 2 * np.sqrt(3) + epsilon
        direction_norm = np.linalg.norm(rayDirection)

        hit_channels = []
        miss_channels = []
        track_distances = []

        for channel in sorted(self.coords):
            cubeCenter = self.coords[channel]

            # Cheap pre-filter: a line farther than the half-diagonal from the
            # centre cannot touch the cube.
            distance_to_line = np.linalg.norm(cross(cubeCenter - rayPoint, rayDirection)) / direction_norm
            if distance_to_line >= reach:
                continue

            collision = self.linecubecollision(cubeCenter, cube_length, rayDirection, rayPoint)
            if len(collision) == 2:
                hit_channels.append(channel)
                track_distances.append(np.linalg.norm(collision[1] - collision[0]))
            else:
                miss_channels.append(channel)

        return (hit_channels, miss_channels, track_distances)

    def pts_to_line(self, line_pts):
        """Convert two points ``(p, a)`` into ``(p, unit direction from p to a)``.

        Returns a 6-element array.  If the two points coincide the direction
        is undefined (division by zero yields NaNs).
        """
        pts = np.array(line_pts)
        p = pts[:3]
        a = pts[3:]

        v = (a - p) / np.linalg.norm(a - p)
        return np.array([p, v]).flatten()

    def line_to_pts(self, line, half_span=800):
        """Convert ``(p, v)`` into a ``(2, 3)`` array of the points ``p -/+ half_span * v``."""
        x = np.array(line)
        p = x[:3]
        v = x[3:]
        return v * np.array([-half_span, half_span], dtype=float)[:, np.newaxis] + p

    def load_coords(self):
        """Read channel centre coordinates from ``<pwd>/data/detector_positions.txt``.

        Each non-blank line is ``channel,x,y,z[,...]``.  Only channels numbered
        below 1000 are kept, and extra columns are ignored.

        Returns
        -------
        dict mapping int channel -> (x, y, z) tuple of floats.
        """
        coords = {}

        with open(self.pwd + "/data/detector_positions.txt", 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. at the end of the file
                data = line.split(',')

                if int(data[0]) < 1000:
                    coords[int(data[0])] = (float(data[1]), float(data[2]), float(data[3]))

        return coords


    
# manually do crossproduct to avoid numpy overhead for small vectors
def cross(a, b):
    """Return the 3-D cross product ``a x b`` as a plain Python list.

    Written out by hand, rather than calling ``np.cross``, to avoid NumPy call
    overhead on single 3-vectors.  Only the first three components of each
    argument are read.
    """
    a_x, a_y, a_z = a[0], a[1], a[2]
    b_x, b_y, b_z = b[0], b[1], b[2]
    return [
        a_y * b_z - a_z * b_y,
        a_z * b_x - a_x * b_z,
        a_x * b_y - a_y * b_x,
    ]

In [None]:
#problem = testMyProblem(hit_chs=[214,215,241], miss_chs=[601,602,603], non_sat_chs=[214,215,241], energies=[30,3500,352])

In [None]:
#problem.lincost([np.array([1,1,1,1,1,1]),np.array([2,2,2,2,2,2])])