# In [5]:
import numpy as np
import pandas as pd
import scipy as sc
import uproot as up
import matplotlib.pyplot as plt
from datetime import datetime, timezone, timedelta
from scipy import stats
import os
import math
from pathlib import Path
import time

# Base directory; the analysis class below builds every input path as
# pwd + "/data/...". The commented-out lines are alternative locations.
pwd = "."
#pwd = str(Path().absolute())
#pwd = "/nfs/cuore1/scratch/yocum"


class testDataAnalysis:
    """Event clustering and straight-line track fitting for the ds3564 data.

    The instance holds three pieces of state:

    * ``coords``    -- dict mapping channel number -> (x, y, z) tuple, read
                       from ``data/detector_positions.txt``.
    * ``eventdf``   -- one row per event; after ``arrange_clusters`` it also
                       carries the ``MaxTime`` and ``Cluster`` columns.
    * ``clusterdf`` -- one row per cluster, built by ``make_clusterdf``.

    Methods that filter or build the frames return ``self`` so they can be
    chained.
    """

    # Edge length used for every detector cube in the collision tests and in
    # the fit cost function.
    _CUBE_LENGTH = 50

    def __init__(self, auto=False, load=False, eventfile=None, clusterfile=None):
        """Load the channel coordinates and optionally the data frames.

        Args:
            auto: run the full pipeline (load ROOT files, filter, cluster,
                build and filter the cluster frame). Ignored when ``load``
                is true.
            load: read ``eventdf`` / ``clusterdf`` from the two CSV files.
            eventfile: CSV path for ``eventdf`` (used when ``load`` is true).
            clusterfile: CSV path for ``clusterdf`` (used when ``load`` is true).
        """
        # Always defined, so the accessors do not raise AttributeError when
        # neither `load` nor `auto` was requested.
        self.eventdf = None
        self.clusterdf = None

        self.coords = self.load_coords()

        if load:
            self.eventdf = pd.read_csv(eventfile)
            self.clusterdf = pd.read_csv(clusterfile)
        elif auto:
            self.load_data()
            self.filter_noisy()
            self.filter_baseline()
            self.arrange_clusters(5, 1.0)  # >= 5 events, gaps <= 1.0 seconds
            self.make_clusterdf()
            self.filter_clusterdf(2.5, 0)  # NRMSE < 2.5, >= 0 channels

    # ------------------------------------------------------------------
    # accessors / persistence
    # ------------------------------------------------------------------

    def get_eventdf(self):
        """Return the event DataFrame."""
        return self.eventdf

    def set_eventdf(self, df):
        """Replace the event DataFrame."""
        self.eventdf = df

    def get_clusterdf(self):
        """Return the cluster DataFrame."""
        return self.clusterdf

    def save_eventdf(self, path):
        """Write the event DataFrame to ``path`` as CSV (no index column)."""
        self.eventdf.to_csv(path, index=False)

    def save_clusterdf(self, path):
        """Write the cluster DataFrame to ``path`` as CSV (no index column)."""
        self.clusterdf.to_csv(path, index=False)

    def get_cluster(self, cluster_num):
        """Return the rows of ``eventdf`` whose ``Cluster`` equals ``cluster_num``."""
        return self.eventdf[self.eventdf['Cluster'] == cluster_num]

    # ------------------------------------------------------------------
    # input files
    # ------------------------------------------------------------------

    def load_data(self):
        """Read the 19 tower ROOT files into ``self.eventdf``.

        ``MaxPosInWindow`` is divided by 1000 (milli -> seconds per the
        original author's comment) and each event's ``Time`` is shifted by
        the start offset of its run (see ``runstarttimes``).
        """
        path = pwd + '/data/ds3564/'
        num_towers = 19
        filename = 'ds3564Tower'

        frames = []
        for t in range(1, num_towers + 1):
            tree = up.open(path + filename + str(t) + '.root')['tree']
            # NOTE(review): `.pandas.df()` looks like the uproot3 API --
            # confirm against the installed uproot version.
            frames.append(tree.pandas.df())

        raw = pd.concat(frames)
        raw['MaxPosInWindow'] = raw['MaxPosInWindow'] / 1000.0

        run_starttimes = self.runstarttimes()
        for run in raw['Run'].unique():
            # Raises KeyError for a run that is not a good background run.
            raw.loc[raw['Run'] == run, ['Time']] += run_starttimes[run]

        self.eventdf = raw

    def _background_runs(self):
        """Yield the stripped '|'-separated fields of each usable run line.

        Skips the two header lines of ``ds3564_start_stop_times.txt`` and
        keeps only lines whose field 2 is "Background" and field 6 is
        "OK (0)". Lines with fewer than seven fields are ignored.
        """
        with open(pwd + "/data/ds3564_start_stop_times.txt") as f:
            f.readline()
            f.readline()
            for line in f:
                fields = [field.strip() for field in line.split('|')]
                if len(fields) > 6 and fields[2] == "Background" and fields[6] == "OK (0)":
                    yield fields

    def runstarttimes(self):
        """Return ``{run number: seconds since the first usable run started}``."""
        times = {}
        first_timestamp = None

        for fields in self._background_runs():
            linedate = datetime.strptime(fields[3], "%b %d, %Y %H:%M:%S%z")
            # NOTE(review): replace() discards the UTC offset parsed by %z and
            # treats the wall-clock time as UTC. Differences are only correct
            # if every line carries the same offset -- confirm this is intended.
            timestamp = linedate.replace(tzinfo=timezone.utc).timestamp()

            if first_timestamp is None:
                first_timestamp = timestamp

            times[int(fields[1])] = timestamp - first_timestamp

        return times

    def get_clusterrate(self):
        """Return clusters per second: ``len(clusterdf)`` / summed run duration.

        The duration of each usable run is field 5 of the start/stop file,
        parsed as ``H:M:S``. Raises ZeroDivisionError if no run qualifies.
        """
        livetime = timedelta()
        for fields in self._background_runs():
            h, m, s = fields[5].split(':')
            livetime += timedelta(hours=int(h), minutes=int(m), seconds=int(s))

        return float(len(self.clusterdf)) / float(livetime.total_seconds())

    def load_coords(self):
        """Return a dict mapping channel number to its (x, y, z) tuple.

        Reads comma-separated ``channel,x,y,z`` lines from
        ``data/detector_positions.txt``; entries with channel >= 1000 and
        blank lines are skipped.
        """
        coords = {}

        with open(pwd + "/data/detector_positions.txt", 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                data = line.split(',')
                channel = int(data[0])
                if channel < 1000:
                    coords[channel] = (float(data[1]), float(data[2]), float(data[3]))

        return coords

    # ------------------------------------------------------------------
    # channel statistics and event filters
    # ------------------------------------------------------------------

    def eventsperchannel(self):
        """Return the event count of channels 1..max(Channel) as a list.

        Element ``i`` is the count for channel ``i + 1``. Returns ``[]`` for
        an empty event frame.
        """
        channels = self.eventdf['Channel']
        if len(channels) == 0:
            return []

        top = int(channels.max())
        # One pass over the column instead of one boolean scan per channel.
        counts = channels.value_counts()
        return counts.reindex(range(1, top + 1), fill_value=0).tolist()

    def deadchannels(self):
        """Return the channels in 1..max(Channel) that have no events."""
        return [c for c, n in enumerate(self.eventsperchannel(), start=1) if n == 0]

    def noisychannels(self):
        """Return channels whose event count exceeds Q3 + 5 * IQR."""
        threshold = 5

        channel_events = self.eventsperchannel()
        if not channel_events:
            return []

        q1, q3 = np.percentile(channel_events, [25, 75])
        upper_bound = q3 + (q3 - q1) * threshold
        return [c for c, n in enumerate(channel_events, start=1) if n > upper_bound]

    def filter_noisy(self):
        """Drop all events recorded on noisy channels; returns ``self``."""
        noisy = self.noisychannels()
        self.eventdf = self.eventdf[np.isin(self.eventdf['Channel'], noisy, invert=True)]
        return self

    def filter_baseline(self):
        """Keep events with ``Baseline + MaxToBaseline > 9000``; returns ``self``."""
        level = self.eventdf['Baseline'] + self.eventdf['MaxToBaseline']
        self.eventdf = self.eventdf[level > 9000]
        return self

    def arrange_clusters(self, e_thresh=5.0, t_thresh=1.0):
        """Group time-adjacent events into clusters and drop the rest.

        Events are sorted by ``MaxTime = Time + MaxPosInWindow``. Consecutive
        events whose gap is <= ``t_thresh`` belong to the same group; groups
        with fewer than ``e_thresh`` events are removed. Surviving groups are
        numbered 0, 1, 2, ... in time order in a new ``Cluster`` column.
        The resulting index is the row position in the time-sorted frame
        (so it has gaps where rows were dropped).

        Args:
            e_thresh: minimum number of events per cluster (cast to int).
            t_thresh: maximum gap between successive events (cast to float).

        Returns:
            ``self``.
        """
        e_thresh = int(e_thresh)
        t_thresh = float(t_thresh)

        sorted_df = self.eventdf.copy()
        sorted_df['MaxTime'] = self.eventdf[['Time', 'MaxPosInWindow']].sum(axis=1)
        sorted_df = sorted_df.sort_values(by=['MaxTime'])
        sorted_df = sorted_df.reset_index(drop=True)

        times = sorted_df['MaxTime'].to_numpy()

        # A row opens a new group when the gap to its predecessor is not
        # within t_thresh; the first row always opens one.
        starts_group = np.ones(len(times), dtype=bool)
        starts_group[1:] = ~(np.abs(np.diff(times)) <= t_thresh)
        group_of_row = np.cumsum(starts_group) - 1

        keep_group = np.bincount(group_of_row) >= e_thresh
        # Consecutive numbering over the surviving groups only.
        cluster_of_group = np.cumsum(keep_group) - 1

        keep_row = keep_group[group_of_row]
        new_df = sorted_df[keep_row].copy()
        new_df['Cluster'] = cluster_of_group[group_of_row[keep_row]].astype(np.int64)

        self.eventdf = new_df
        return self

    # ------------------------------------------------------------------
    # geometry helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _line_through(point, direction, half_length):
        """Return a (2, 3) array: ``point -/+ half_length * direction``."""
        ends = np.array([-half_length, half_length], dtype=float)
        return ends[:, np.newaxis] * direction + point

    def clustercoords(self, cluster):
        """Return ``[xs, ys, zs]`` for the channels of ``cluster``, row by row."""
        points = [self.coords[c] for c in cluster['Channel']]
        return [[p[0] for p in points],
                [p[1] for p in points],
                [p[2] for p in points]]

    def basicfit(self, cluster):
        """Least-squares line through the cluster's channel positions.

        Returns:
            ``(mean_point, direction)`` where ``direction`` is the first
            right-singular vector of the mean-centred positions.
        """
        data = np.array(self.clustercoords(cluster)).T
        datamean = data.mean(axis=0)

        # SVD of the mean-centred data; the first row of vv is the direction
        # of largest variance.
        _, _, vv = np.linalg.svd(data - datamean)

        return (datamean, vv[0])

    def ptsfromline(self, pts, linepts):
        """Distance of each point from the line through ``linepts[0]``/``[1]``.

        Uses d = |(p - a) x (p - b)| / |b - a|.

        Args:
            pts: three sequences (xs, ys, zs), i.e. shape (3, N).
            linepts: two points on the line.

        Returns:
            List of N floats (empty list for empty ``pts``).
        """
        pts = np.asarray(pts, dtype=float)
        if pts.size == 0:
            return []

        a = np.asarray(linepts[0], dtype=float)
        b = np.asarray(linepts[1], dtype=float)
        p = pts.T  # one row per point

        d = np.linalg.norm(np.cross(p - a, p - b), axis=1) / np.linalg.norm(b - a)
        return d.tolist()

    def cost(self, x, hit_pts, miss_pts):
        """Fit cost of the line ``x = [p0, p1, p2, v0, v1, v2]``.

        Every point in ``hit_pts`` contributes a logistic term that grows
        with its distance from the line; every point in ``miss_pts``
        contributes one that shrinks with distance. Both are centred on
        ``cubeLength / 2 * sqrt(3)``.

        Args:
            x: line as point ``x[:3]`` and direction ``x[3:]``.
            hit_pts: (3, N) positions the line should pass close to.
            miss_pts: (3, M) positions the line should stay away from; may
                be empty.
        """
        inside = self._CUBE_LENGTH / 2 * np.sqrt(3)

        x = np.asarray(x, dtype=float)
        linepts = self._line_through(x[:3], x[3:], 0.5)

        hit_d = np.asarray(self.ptsfromline(hit_pts, linepts), dtype=float)
        total = np.sum(1 / (1 + np.exp(-.1 * (hit_d - inside))))

        if len(miss_pts) == 0:
            return total

        miss_d = np.asarray(self.ptsfromline(miss_pts, linepts), dtype=float)
        return total + np.sum(1 / (1 + np.exp(.1 * (miss_d - inside))))

    def fitline(self, cluster, trials=20, steps=200, eps=1e-6):
        """Fit a line to ``cluster`` and return two points on it.

        Starts from ``basicfit`` and then, for ``trials`` rounds, collects
        channels the current line should avoid and runs ``steps`` gradient
        steps on ``cost``. Returns early with the current line as soon as
        there is nothing to avoid.

        Args:
            cluster: DataFrame of one cluster (needs a ``Channel`` column).
            trials: number of outer rounds.
            steps: gradient steps per round.
            eps: finite-difference step passed to ``approx_fprime``; it is
                also used as the gradient step size.

        Returns:
            (2, 3) array with the two end points of the fitted line.
        """
        from scipy.optimize import approx_fprime

        p0, v0 = self.basicfit(cluster)
        x = np.append(p0, v0)

        linepts = self._line_through(p0, v0, 400)
        first_hit_chs, _ = self.channelcollisions(linepts)

        want_to_hit_chs = cluster['Channel'].values
        hit_set = set(want_to_hit_chs.tolist())
        want_to_miss_chs = [ch for ch in first_hit_chs if ch not in hit_set]
        miss_set = set(want_to_miss_chs)

        # Does not change between rounds, so build it once.
        hit_pts = np.array([self.coords[ch] for ch in want_to_hit_chs]).T
        fd_step = np.full(x.shape, eps)

        for _ in range(trials):
            linepts = self._line_through(x[:3], x[3:], 400)
            _, miss_chs = self.channelcollisions(linepts)

            for ch in miss_chs:
                if ch not in hit_set and ch not in miss_set:
                    want_to_miss_chs.append(ch)
                    miss_set.add(ch)

            if not want_to_miss_chs:
                return linepts

            miss_pts = np.array([self.coords[ch] for ch in want_to_miss_chs]).T

            for _ in range(steps):
                g = approx_fprime(x, self.cost, fd_step, hit_pts, miss_pts)
                # NOTE(review): the step size equals the finite-difference
                # eps (1e-6 by default), so x barely moves -- confirm this
                # learning rate is intended.
                x -= g * eps

        direction = x[3:] / np.linalg.norm(x[3:])
        return self._line_through(x[:3], direction, 400)

    def clusterNRMSE(self, cluster, linepoints=None):
        """Normalised RMS distance of the cluster's channels from its fit line.

        The residual of each event is the distance from its channel position
        to the line, d = |(p-a)x(p-b)|/|b-a|. RMSE uses ``2 * n - 4`` as the
        denominator and the result is divided by ``4.54 ** 2``.
        Clusters with two or fewer events return 0.

        Args:
            cluster: DataFrame of one cluster.
            linepoints: optional precomputed result of ``fitline(cluster)``;
                computed here when omitted.
        """
        if len(cluster) <= 2:
            return 0

        if linepoints is None:
            linepoints = self.fitline(cluster)

        dlist = self.ptsfromline(self.clustercoords(cluster), linepoints)

        RMSE = math.sqrt(sum(d ** 2 for d in dlist) / (2 * len(dlist) - 4))

        # NOTE(review): meaning of the 4.54 normalisation constant is not
        # visible in this file.
        return RMSE / 4.54 ** 2

    def errorchannels(self, cluster, linepoints=None):
        """Compare the cluster's channels with those its fit line crosses.

        Args:
            cluster: DataFrame of one cluster.
            linepoints: optional precomputed result of ``fitline(cluster)``.

        Returns:
            ``(missing, extra)``: cluster channels the line does not hit,
            and channels the line hits that are not in the cluster.
        """
        if linepoints is None:
            linepoints = self.fitline(cluster)

        hitchannels = self.channelcollisions(linepoints)[0]
        clusterchannels = cluster['Channel'].unique()

        extra = [ch for ch in hitchannels if ch not in clusterchannels]
        missing = [ch for ch in clusterchannels if ch not in hitchannels]

        return (missing, extra)

    def zenith(self, linepoints):
        """Angle in radians between the line and the z axis, in [0, pi/2]."""
        z = abs(linepoints[0][2] - linepoints[1][2])
        d = abs(np.linalg.norm(linepoints[0] - linepoints[1]))
        # Rounding can push z/d a hair above 1, which acos rejects; clamp it.
        # (Argument order keeps a NaN ratio as NaN.)
        return math.acos(min(z / d, 1.0))

    def azimuth(self, linepoints):
        """Angle in radians of the line's xy projection, in [-pi/2, pi/2]."""
        y = linepoints[0][1] - linepoints[1][1]
        x = linepoints[0][0] - linepoints[1][0]

        # A line has no orientation: flip so x >= 0 (1st and 4th quadrant).
        if x < 0:
            y *= -1
            x *= -1

        return math.atan2(y, x)

    def lineplanecollision(self, planeNormal, planePoint, rayDirection, rayPoint, epsilon=1e-6):
        """Intersection of an infinite line with a plane.

        Returns the intersection point, or ``None`` when
        ``|planeNormal . rayDirection| < epsilon`` (treated as parallel).
        """
        ndotu = planeNormal.dot(rayDirection)
        if abs(ndotu) < epsilon:
            return None

        t = -planeNormal.dot(rayPoint - planePoint) / ndotu

        return rayPoint + t * rayDirection

    def linecubecollision(self, cubeCenter, cubeLength, rayDirection, rayPoint, epsilon=1e-6):
        """Points where a line crosses the faces of an axis-aligned cube.

        The line is intersected with the six face planes; an intersection is
        kept when all three coordinates lie within the cube, widened by
        ``epsilon``. A line through an edge or corner can therefore report
        the same point more than once.

        Returns:
            List of intersection points (arrays of length 3).
        """
        cubeCenter = np.asarray(cubeCenter, dtype=float)
        halfLength = cubeLength / 2.0

        directions = halfLength * np.array([
            [0.0, 0.0, 1.0],  # up
            [0.0, 1.0, 0.0],  # front
            [1.0, 0.0, 0.0],  # right
        ])
        # down, back, left are the negated normals, in the same order.
        faceNormals = np.vstack([directions, -directions])

        upper = cubeCenter + halfLength + epsilon
        lower = cubeCenter - halfLength - epsilon

        cubeCollisions = []
        for faceNormal in faceNormals:
            collision = self.lineplanecollision(
                faceNormal, cubeCenter + faceNormal, rayDirection, rayPoint)
            if collision is None:
                continue

            outside = np.any((collision > upper) | (collision < lower))
            if not outside:
                cubeCollisions.append(collision)

        return cubeCollisions

    def channelcollisions(self, linepoints, epsilon=1e-6):
        """Classify channels near the line as hit or missed.

        Only channels whose centre lies within the cube's bounding-sphere
        radius (``cubeLength / 2 * sqrt(3) + epsilon``) of the line are
        examined. Such a channel is a hit when the line crosses its cube in
        exactly two points and a miss otherwise.

        Args:
            linepoints: two points on the line.
            epsilon: tolerance added to the bounding-sphere radius.

        Returns:
            ``(hit_channels, miss_channels)``, each in ascending channel order.
        """
        linepoints = np.asarray(linepoints, dtype=float)
        rayPoint = linepoints[0]
        rayDirection = linepoints[1] - linepoints[0]
        cubeLength = self._CUBE_LENGTH

        hit_channels = []
        miss_channels = []

        # Use the channels that actually exist rather than assuming 1..N.
        channels = sorted(self.coords)
        if not channels:
            return (hit_channels, miss_channels)

        centers = np.array([self.coords[ch] for ch in channels], dtype=float)

        # Distance of every cube centre from the line, all at once.
        distances = (np.linalg.norm(np.cross(centers - rayPoint, rayDirection), axis=1)
                     / np.linalg.norm(rayDirection))
        reach = cubeLength / 2 * np.sqrt(3) + epsilon

        for idx in np.flatnonzero(distances < reach):
            collision = self.linecubecollision(centers[idx], cubeLength, rayDirection, rayPoint)
            if len(collision) == 2:
                hit_channels.append(channels[idx])
            else:
                miss_channels.append(channels[idx])

        return (hit_channels, miss_channels)

    # ------------------------------------------------------------------
    # cluster summary
    # ------------------------------------------------------------------

    def make_clusterdf(self):
        """Build ``self.clusterdf`` with one summary row per cluster.

        Columns: Cluster, Events, Channels, StartTime, TimeSpread, NRMSE,
        Zenith, Azimuth, ExtraCh, MissingCh. Cluster ids are taken from the
        ``Cluster`` column of ``eventdf``; negative ids (the "unassigned"
        marker) are skipped. The line fit is computed once per cluster and
        shared by all derived quantities.

        Returns:
            ``self``.
        """
        clusters = np.unique(self.eventdf['Cluster'])
        clusters = clusters[clusters >= 0]

        eventspercluster = []
        channelspercluster = []
        starttimes = []
        timespreads = []
        NRMSE = []
        extrachannels = []
        missingchannels = []
        zeniths = []
        azimuths = []

        for c in clusters:
            cluster = self.eventdf[self.eventdf['Cluster'] == c]

            eventspercluster.append(len(cluster))
            channelspercluster.append(len(cluster['Channel'].unique()))

            clustertimes = cluster['MaxTime']
            first, last = clustertimes.min(), clustertimes.max()
            starttimes.append(first)
            timespreads.append(last - first)

            # fitline is by far the most expensive step: run it once.
            linepoints = self.fitline(cluster)

            NRMSE.append(self.clusterNRMSE(cluster, linepoints))

            missing, extra = self.errorchannels(cluster, linepoints)
            extrachannels.append(len(extra))
            missingchannels.append(len(missing))

            zeniths.append(self.zenith(linepoints))
            azimuths.append(self.azimuth(linepoints))

        d = {'Cluster': clusters, 'Events': eventspercluster, 'Channels': channelspercluster,
             'StartTime': starttimes, 'TimeSpread': timespreads, 'NRMSE': NRMSE,
             'Zenith': zeniths, 'Azimuth': azimuths, 'ExtraCh': extrachannels,
             'MissingCh': missingchannels}

        self.clusterdf = pd.DataFrame(data=d)

        return self

    def filter_clusterdf(self, NRMSE, channels):
        """Keep clusters with ``NRMSE`` below the limit and at least ``channels`` channels.

        Returns:
            ``self``.
        """
        self.clusterdf = self.clusterdf[self.clusterdf['NRMSE'] < NRMSE]
        self.clusterdf = self.clusterdf[self.clusterdf['Channels'] >= channels]

        return self

    # ------------------------------------------------------------------
    # plotting
    # ------------------------------------------------------------------

    def _new_3d_axes(self):
        """Open a 10x10 figure and return an orthographic 3D axes."""
        plt.figure(figsize=(10, 10))
        ax = plt.axes(projection='3d')
        ax.set_proj_type('ortho')
        return ax

    def _finish_3d_axes(self, ax, x1, x2):
        """Apply the common limits, labels and view angles, then show the plot."""
        plt.xlim([-350, 350])
        plt.ylim([-350, 350])
        ax.set_zlim([-350, 350])

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')

        ax.view_init(x1, x2)

        plt.show()

    def show_channel(self, channel_list, x1=15, x2=45):
        """Scatter-plot the positions of one channel or a list of channels.

        ``x1`` and ``x2`` are passed to ``ax.view_init``.
        """
        ax = self._new_3d_axes()

        # Accept a single channel number as well as a list.
        if isinstance(channel_list, (int, np.integer)):
            channel_list = [channel_list]

        coords = np.array([self.coords[ch] for ch in channel_list]).T
        ax.scatter3D(*coords)

        self._finish_3d_axes(ax, x1, x2)

    def show_cluster(self, cluster_list, x1=15, x2=45):
        """Plot the channels and fitted line of one cluster or a list of clusters.

        ``x1`` and ``x2`` are passed to ``ax.view_init``.
        """
        ax = self._new_3d_axes()

        # Accept a single cluster number as well as a list.
        if isinstance(cluster_list, (int, np.integer)):
            cluster_list = [cluster_list]

        for c in cluster_list:
            cluster = self.eventdf[self.eventdf['Cluster'] == c]

            ax.scatter3D(*self.clustercoords(cluster))

            line = self.fitline(cluster)
            ax.plot3D(*line.T)

        self._finish_3d_axes(ax, x1, x2)

    def show_simulation(self, cluster_num, x1=15, x2=45):
        """Plot the fitted line of a cluster and every channel that line hits.

        ``x1`` and ``x2`` are passed to ``ax.view_init``.
        """
        linepoints = self.fitline(self.get_cluster(cluster_num))
        hit_channels = self.channelcollisions(linepoints)[0]
        hit_channel_coords = np.array([self.coords[channel] for channel in hit_channels])

        ax = self._new_3d_axes()

        ax.scatter3D(*hit_channel_coords.T)
        ax.plot3D(*linepoints.T)

        self._finish_3d_axes(ax, x1, x2)

        
# manually do crossproduct to avoid numpy overhead for small vectors
def cross(a, b):
    """Return the 3D cross product a x b as a plain list of three numbers.

    Written by hand (instead of calling numpy) to keep the per-call cost low
    for single small vectors. Only indices 0..2 of each argument are read.
    """
    # Each component pairs the two "other" axes in cyclic order.
    axis_pairs = ((1, 2), (2, 0), (0, 1))
    return [a[i] * b[j] - a[j] * b[i] for i, j in axis_pairs]