In [1]:
import numpy as np
import matplotlib.pyplot as plt
import bintrees as bt

In [2]:
def detTest(px, py, qx, qy, rx, ry):
    """Orientation (determinant) test of point r against the directed line p -> q.

    Returns 1 if r is left of line pq, -1 if r is right of line pq, 0 if r is on
    line pq.  This is the sign of det([[1,px,py],[1,qx,qy],[1,rx,ry]]), expanded
    to (qx-px)*(ry-py) - (qy-py)*(rx-px) and evaluated with plain Python
    arithmetic: exact for int coordinates of any size, and float coordinates are
    no longer truncated to integers.
    """
    d = (qx - px) * (ry - py) - (qy - py) * (rx - px)
    # bool subtraction gives an int in {-1, 0, 1}
    return (d > 0) - (d < 0)
    
def cross(px, py, qx, qy, rx, ry, sx, sy):
    """Return True iff segments pq and rs properly cross.

    Each segment's endpoints must lie strictly on opposite sides of the other
    segment's supporting line, so touching at an endpoint or collinear overlap
    is not reported as a crossing.
    """
    # detTest yields -1/0/1, so a product of -1 means "one strictly left, one strictly right"
    if detTest(px, py, qx, qy, rx, ry) * detTest(px, py, qx, qy, sx, sy) != -1:
        return False
    return detTest(rx, ry, sx, sy, px, py) * detTest(rx, ry, sx, sy, qx, qy) == -1

In [3]:
# classes for points and segments
class Point:
    """A 2-D point, ordered lexicographically: by x first, then by y."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def plot(self):
        """Scatter this point on the current matplotlib axes."""
        plt.scatter(self.x, self.y)

    # For each ordering: a differing x decides; otherwise y decides.
    def __lt__(self, other):
        if self.x != other.x:
            return self.x < other.x
        return self.y < other.y

    def __gt__(self, other):
        if self.x != other.x:
            return self.x > other.x
        return self.y > other.y

    def __eq__(self, other):
        return not (self.x != other.x or self.y != other.y)

    def __le__(self, other):
        if self.x != other.x:
            return self.x < other.x
        return self.y <= other.y

    def __ge__(self, other):
        if self.x != other.x:
            return self.x > other.x
        return self.y >= other.y

    def __ne__(self, other):
        return not self == other

# color is 0 for red, 1 for blue
class Segment:
    """A coloured line segment usable as a sweep-line tree key.

    Endpoints are stored so that ``p`` is the lexicographically smaller point
    (by x, then y) and ``q`` the larger one.  ``color`` is 0 for red, 1 for blue.
    Segments are ordered by "aboveness": ``a < b`` means ``a`` lies below ``b``.
    The order is only meaningful for segments that do not cross and that both
    span the current sweep position.  Identical red and blue segments order red
    below blue.
    """

    def __init__(self, pointA, pointB, color):
        if pointA < pointB:
            self.p = pointA
            self.q = pointB
        else:
            self.p = pointB
            self.q = pointA
        self.color = color

    def cmp(self, p):
        """Orientation of point p w.r.t. the directed line self.p -> self.q.

        1 if p is left of the line (above it, for a non-vertical segment since
        p.x <= q.x), -1 if right of/below it, 0 if on the line.
        """
        return detTest(self.p.x, self.p.y, self.q.x, self.q.y, p.x, p.y)

    def cross(self, segB):
        """True iff this segment and segB properly cross (module-level cross())."""
        return cross(self.p.x, self.p.y, self.q.x, self.q.y,
                     segB.p.x, segB.p.y, segB.q.x, segB.q.y)

    def col(self, seg):
        """True iff both endpoints of seg lie on this segment's supporting line."""
        return self.cmp(seg.p) == 0 and self.cmp(seg.q) == 0

    def slope(self):
        """dy/dx, or +inf for a vertical (or degenerate) segment."""
        dx = self.q.x - self.p.x
        if dx == 0:
            return float('inf')
        return (self.q.y - self.p.y) / dx

    def plot(self):
        """Draw the segment on the current matplotlib axes: 'r' for color 0, 'b' otherwise."""
        plt.plot([self.p.x, self.q.x], [self.p.y, self.q.y], 'r' if self.color == 0 else 'b')

    def _is_point(self):
        # degenerate segment (p == q), e.g. the probe built by ActiveTree.pointTest
        return self.p == self.q

    def _is_vertical(self):
        return self.p.x == self.q.x

    def _below(self, other):
        """True iff this segment lies strictly below other.

        The test endpoint comes from the segment whose left endpoint is further
        right. When both segments span the sweep position, only that endpoint is
        guaranteed to lie inside the other's x-range. Testing the earlier endpoint
        against the other's *extended* line can give contradictory answers.

        Two cases are handled specially:
        - A degenerate (point) segment is always the one tested.
        - Nothing is tested against a vertical line, because left/right of a
          vertical line is not above/below.

        If the test endpoint lies on the other line, e.g. at a shared left
        endpoint, the far endpoint breaks the tie.
        """
        if self.p == other.p and self.q == other.q:
            # identical geometry: red (0) sorts below blue (1)
            return self.color < other.color
        if self._is_point():
            probe_self = True
        elif other._is_point():
            probe_self = False
        elif self._is_vertical() != other._is_vertical():
            probe_self = self._is_vertical()
        else:
            probe_self = other.p <= self.p
        if probe_self:
            side = other.cmp(self.p)
            if side == 0:
                side = other.cmp(self.q)
            return side == -1
        side = self.cmp(other.p)
        if side == 0:
            side = self.cmp(other.q)
        return side == 1

    # ordered by aboveness (no intersections); red below blue if the segments coincide
    def __lt__(self, other):
        return self._below(other)
    def __gt__(self, other):
        return other._below(self)
    def __eq__(self, other):
        return self.p == other.p and self.q == other.q and self.color == other.color
    def __le__(self, other):
        return self < other or self == other
    def __ge__(self, other):
        return self > other or self == other
    def __ne__(self, other):
        return not self == other
    
            

# Flags pair a segment with one of its endpoints: type 0 = start point, 1 = end point.
class Flag:
    """A sweep event: one endpoint ``pt`` of segment ``seg``.

    ``type`` is 0 when ``pt`` is ``seg.p`` (start) and 1 when it is ``seg.q`` (end).
    The neighbour slots start as None and are filled in by aboveBelow:
    ``sb``/``sa`` are the same-colour segments below/above, and ``ob``/``oa``
    are the other-colour segments below/above.

    Raises:
        NameError: if ``endpoint`` is neither endpoint of ``segment``.
    """

    def __init__(self, segment, endpoint):
        self.seg = segment
        self.pt = endpoint
        if self.pt == self.seg.p:
            self.type = 0
        elif self.pt == self.seg.q:
            self.type = 1
        else:
            raise NameError('point is not an endpoint of segment')
        # must be attributes (not locals) so aboveBelow's results are visible later
        self.sa = None
        self.sb = None
        self.oa = None
        self.ob = None

    def cmp(self, flagB):
        """Three-way comparison used for sorting: returns -1, 0 or 1.

        Order is by point (lexicographic). At the same point, end flags come
        before start flags. Remaining ties are broken by increasing slope, then
        by colour: red first among end flags, blue first among start flags.
        """
        if self.pt < flagB.pt:
            return -1
        if self.pt > flagB.pt:
            return 1
        if self.type == 1 and flagB.type == 0:
            return -1
        if self.type == 0 and flagB.type == 1:
            return 1
        mySlope = self.seg.slope()
        theirSlope = flagB.seg.slope()
        if mySlope < theirSlope:
            return -1
        if mySlope > theirSlope:
            return 1
        # colour that sorts first at this tie: red (0) for end flags, blue (1) for start flags
        first, second = (0, 1) if self.type == 1 else (1, 0)
        if self.seg.color == first and flagB.seg.color == second:
            return -1
        if self.seg.color == second and flagB.seg.color == first:
            return 1
        return 0

    def __lt__(self, other):
        return self.cmp(other) < 0
    def __gt__(self, other):
        return self.cmp(other) > 0
    def __eq__(self, other):
        return self.cmp(other) == 0
    def __le__(self, other):
        return self.cmp(other) <= 0
    def __ge__(self, other):
        return self.cmp(other) >= 0
    def __ne__(self, other):
        return self.cmp(other) != 0

In [4]:
# for building segments
class AllSegments:
    """Collection of red (color 0) and blue (color 1) segments plus their sorted flags."""

    def __init__(self):
        self.red = []
        self.blue = []
        self.sortedFlags = []

    # adding segments
    def addRed(self, px, py, qx, qy):
        self.red.append(Segment(Point(px, py), Point(qx, qy), 0))

    def addBlue(self, px, py, qx, qy):
        self.blue.append(Segment(Point(px, py), Point(qx, qy), 1))

    # create red square with corners (-x,-x), (-x,x), (x,-x), (x,x)
    def addRedSq(self, x):
        self.addRed(-x, -x, -x, x)
        self.addRed(-x, -x, x, -x)
        self.addRed(x, x, x, -x)
        self.addRed(x, x, -x, x)

    def sortFlags(self):
        """Build a start and end flag for every segment, sort them, and return the list.

        The result is also stored in ``self.sortedFlags``. Flags are created
        red-first, each segment's start before its end. Python's sort is stable,
        so this pre-order decides the order of flags that compare equal.
        """
        flags = [Flag(seg, end) for seg in self.red + self.blue for end in (seg.p, seg.q)]
        self.sortedFlags = sorted(flags)
        return self.sortedFlags

    def plot(self):
        """Draw every segment, print its endpoints and colour name, then show the figure."""
        for group, name in ((self.red, 'red'), (self.blue, 'blue')):
            for seg in group:
                seg.plot()
                print((seg.p.x, seg.p.y), (seg.q.x, seg.q.y), name)
        plt.axis('equal')
        plt.show()

In [6]:
# finding above/below segment for all flags
class ActiveTree:
    """Segments of one colour currently active in the sweep.

    Segments are the keys of a bintrees RBTree (values unused, always 0),
    ordered by Segment's aboveness comparison. Two horizontal sentinels span x
    in [-bound, bound], at y = -bound and y = +bound. They ensure every insert
    or query has a neighbour below and above.

    Args:
        color: colour given to the sentinel segments.
        bound: sentinel extent; must exceed every coordinate in the input (default 9**9).
    """

    def __init__(self, color, bound=9**9):
        self.segs = bt.RBTree()
        self.segs.insert(Segment(Point(-bound, -bound), Point(bound, -bound), color), 0)
        self.segs.insert(Segment(Point(-bound, bound), Point(bound, bound), color), 0)

    def _neighbors(self, key):
        # [key of predecessor, key of successor] of a key already in the tree
        return [self.segs.prev_item(key)[0], self.segs.succ_item(key)[0]]

    def insert(self, seg):
        """Add seg and return [segment just below it, segment just above it]."""
        self.segs.insert(seg, 0)
        return self._neighbors(seg)

    def remove(self, seg):
        """Remove seg (bintrees raises KeyError if it is absent)."""
        self.segs.remove(seg)

    def pointTest(self, point, color):
        """Return [segment just below point, segment just above point] without keeping point.

        A degenerate segment (point, point, color) is inserted temporarily as a
        probe. It is removed even if a neighbour lookup raises, so a failed query
        cannot leave the probe in the tree.
        """
        probe = Segment(point, point, color)
        self.segs.insert(probe, 0)
        try:
            return self._neighbors(probe)
        finally:
            self.segs.remove(probe)


def aboveBelow(data):
    """Sweep data's flags in sorted order and record each flag's neighbour segments.

    For a start flag, ``sb``/``sa`` are the segments just below/above it in its
    own colour's active tree. For an end flag the segment is removed and both
    slots are set to the flag's own segment. In both cases ``ob``/``oa`` are the
    segments just below/above the flag's point in the other colour's tree.
    """
    data.sortFlags()
    activeRed = ActiveTree(0)
    activeBlue = ActiveTree(1)
    for flag in data.sortedFlags:
        # pick (own tree, opposite tree, colour tag for the point probe)
        if flag.seg.color == 0:
            own, opposite, tag = activeRed, activeBlue, 0
        else:
            own, opposite, tag = activeBlue, activeRed, 1

        if flag.type == 0:
            flag.sb, flag.sa = own.insert(flag.seg)
        else:
            own.remove(flag.seg)
            flag.sb = flag.sa = flag.seg

        flag.ob, flag.oa = opposite.pointTest(flag.pt, tag)

In [9]:
# using a doubly linked list would make merge and split faster; using RBTrees makes insert/remove faster
class Bundle:
    """A group of same-coloured segments stored in a bintrees RBTree.

    Segments are the tree keys (values unused, always 0). ``min`` and ``max``
    are the lowest and highest member segments and drive the ordering between
    bundles. After the last segment is removed they become None.
    """

    def __init__(self, seg):
        self.bundle = bt.RBTree()
        self.min = seg
        self.max = seg
        self.color = seg.color
        # the segment itself is the key, as in insert()
        self.bundle.insert(seg, 0)

    # only works for same-coloured bundles, since the aboveness of their segments does not change
    def __lt__(self, other):
        return self.max < other.min
    def __gt__(self, other):
        return self.min > other.max
    def __eq__(self, other):
        # neither strictly below nor strictly above: the bundles overlap, contain one another, or coincide
        return not (self < other or self > other)
    def __le__(self, other):
        return self < other or self == other
    def __ge__(self, other):
        return self > other or self == other
    def __ne__(self, other):
        return not self == other

    def _refresh_bounds(self):
        # keep min/max in sync with the tree so the comparisons above stay valid
        if len(self.bundle) == 0:
            self.min = self.max = None
        else:
            self.min = self.bundle.min_key()
            self.max = self.bundle.max_key()

    def insert(self, seg):
        """Add seg to the bundle and update min/max."""
        self.bundle.insert(seg, 0)
        self._refresh_bounds()

    def remove(self, seg):
        """Remove seg from the bundle (bintrees raises KeyError if absent) and update min/max."""
        self.bundle.remove(seg)
        self._refresh_bounds()
        
class BundleTree:
    """Container intended to hold Bundle objects in a bintrees RBTree.

    Work in progress: only the empty tree is created so far (presumably keyed
    by Bundle, using its ordering — TODO confirm once implemented).
    """

    def __init__(self):
        self.tree = bt.RBTree()
        
    

SyntaxError: unexpected EOF while parsing (<ipython-input-9-a9c24798e8ff>, line 1)