In [1]:
import numpy as np
import pyglet
from pyglet.window import key
import json

# Window / world size in pixels (boxes are clamped to this area).
width, height = 800, 400

# Simulation state.
dt = 0.1            # physics time step
time = 0            # sim clock, advanced by dt each tick; Brain reads sin(time)
mousex = mousey = 0
gravity = False
training = False
mouseDown = False
bs = []             # boxes (point masses)
ss = []             # springs joining pairs of boxes
selected = None     # box grabbed by the mouse / chosen with S

# Training state.
best = 0            # largest rightward centroid displacement seen in a trial
brain = None        # brain currently driving the springs
best_brain = None   # brain that achieved `best`

window = pyglet.window.Window(width, height)

def dist(x1, y1, x2, y2):
    """Return the Euclidean distance between (x1, y1) and (x2, y2)."""
    dx = x1 - x2
    dy = y1 - y2
    return np.sqrt(dx * dx + dy * dy)

def avg_p(boxes=None):
    """Return the centroid of a set of boxes as ``{"x": ..., "y": ...}``.

    Args:
        boxes: iterable of objects with ``x``/``y`` attributes. Defaults to the
            global box list ``bs``.

    Returns:
        A dict with the mean ``x`` and mean ``y``. It is ``{"x": 0, "y": 0}``
        when there are no boxes. The original built ``{x: 0, y: 0}`` with
        undefined names and raised NameError in that case.
    """
    if boxes is None:
        boxes = bs
    boxes = list(boxes)
    if not boxes:
        return {"x": 0, "y": 0}

    n = len(boxes)
    return {"x": sum(b.x for b in boxes) / n, "y": sum(b.y for b in boxes) / n}

class Brain:
    """Two-layer tanh network used to drive the springs.

    Input is a length-``in_dim`` vector, augmented with ``sin(time)`` and a bias
    of 1. Output is a length-``out_dim`` vector in (-1, 1). In this file it is
    fed per-spring strains and its outputs become spring rest-length offsets.
    """

    def __init__(self, in_dim, out_dim, h_dim=20):
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.h_dim = h_dim
        self.rand_W()

    def rand_W(self):
        """Draw fresh Gaussian weights scaled by 0.4."""
        # +2 input columns: the sin(time) clock signal and the bias.
        self.W1 = np.random.randn(self.h_dim, self.in_dim + 2) * .4
        # +1 hidden column: the bias.
        self.W2 = np.random.randn(self.out_dim, self.h_dim + 1) * .4

    def process(self, springs):
        """Run the network forward on the 1-D array ``springs``."""
        clock = np.array([np.sin(time)])
        bias = np.array([1])
        hidden = np.tanh(self.W1.dot(np.hstack((springs, clock, bias))))
        return np.tanh(self.W2.dot(np.hstack((hidden, bias))))

class Box:
    """Point mass with position, velocity and an accumulated force (ax, ay).

    ``i`` is an explicit id used when restoring from a save. The default -1
    means "draw a random id".
    """

    def __init__(self, x, y, m=1, i=-1):
        self.x, self.y = x, y
        self.vx = self.vy = 0
        self.ax = self.ay = 0
        self.m = m
        self.id = np.random.randint(1e10) if i == -1 else i

    def data(self):
        """Return the fields persisted by export()."""
        return {"id": self.id, "x": self.x, "y": self.y, "m": self.m}

    def update(self):
        """Integrate one tick, apply damping, and clamp to the window."""
        # Gravity is applied as a force proportional to mass; after dividing
        # by m below it becomes a constant acceleration of 8.
        if gravity:
            self.ay -= self.m * 8

        # Force -> velocity (scaled by dt); velocity -> position (not scaled).
        self.vx += self.ax / self.m * dt
        self.vy += self.ay / self.m * dt
        self.x += self.vx
        self.y += self.vy

        # Forces are re-accumulated every tick. Damping keeps 80% of velocity.
        self.ax = self.ay = 0
        self.vx *= 0.8
        self.vy *= 0.8

        # Wall contact: the normal component bounces back at half speed and
        # the tangential component is zeroed.
        if self.x < 0:
            self.x, self.vx, self.vy = 0, self.vx * -0.5, self.vy * 0.0
        if self.y < 0:
            self.y, self.vy, self.vx = 0, self.vy * -0.5, self.vx * 0.0
        if self.x > width:
            self.x, self.vx, self.vy = width, self.vx * -0.5, self.vy * 0.0
        if self.y > height:
            self.y, self.vy, self.vx = height, self.vy * -0.5, self.vx * 0.0

    def render(self):
        """Draw the box as a 10x10 quad centred on (x, y)."""
        size = 10
        right = self.x + size / 2
        top = self.y + size / 2
        left = right - size
        bottom = top - size
        quad = [right, top, left, top, left, bottom, right, bottom]
        pyglet.graphics.draw(4, pyglet.gl.GL_QUADS, ('v2f', quad))
        
class Spring:
    """Linear spring between boxes ``a`` and ``b``.

    The rest length ``d`` is the distance at creation. ``offset`` is added to
    the stretch and is set externally (by brain_update). ``diff`` holds the
    last computed ``d - current_length``.
    """

    def __init__(self, a, b):
        self.a, self.b = a, b
        self.k = 40  # stiffness
        self.d = dist(a.x, a.y, b.x, b.y)
        self.id = np.random.randint(1e10)
        self.offset = 0
        self.diff = 0

    def data(self):
        """Return the fields persisted by export(); endpoints are stored by box id."""
        return {"d": self.d, "a": self.a.id, "b": self.b.id}

    def update(self):
        """Add equal and opposite spring forces (scaled by dt) to both boxes."""
        pa, pb = self.a, self.b
        self.diff = self.d - dist(pa.x, pa.y, pb.x, pb.y)
        force = (self.diff + self.offset) * self.k
        # Direction from b towards a: positive force pushes the boxes apart.
        theta = np.arctan2(pa.y - pb.y, pa.x - pb.x)
        fx = np.cos(theta) * force * dt
        fy = np.sin(theta) * force * dt
        pa.ax += fx
        pa.ay += fy
        pb.ax -= fx
        pb.ay -= fy

    def render(self):
        """Draw the spring as a line segment between its two boxes."""
        endpoints = (self.a.x, self.a.y, self.b.x, self.b.y)
        pyglet.graphics.draw(2, pyglet.gl.GL_LINES, ("v2f", endpoints))

def closest_box(x, y):
    """Return the box in ``bs`` nearest to (x, y), or None if there are no boxes.

    The search starts from a 1e6 distance cap. If every box is farther than
    that, the first box is returned. Ties go to the earliest box.
    """
    if not bs:
        return None

    best_box, best_d = bs[0], 1000000
    for candidate in bs:
        d = dist(candidate.x, candidate.y, x, y)
        if d < best_d:
            best_box, best_d = candidate, d
    return best_box

def box_with_id(i):
    """Return the first box in ``bs`` whose id equals ``i``.

    Prints a message and returns None when no box matches.
    """
    found = next((box for box in bs if box.id == i), None)
    if found is None:
        print("box not found", i)
    return found

def export():
    """Write all boxes and springs to ``save_file.txt`` as one JSON line."""
    snapshot = {
        "boxes": [box.data() for box in bs],
        "springs": [spring.data() for spring in ss],
    }
    # Encode before opening, so an encoding error does not truncate the file.
    encoded = json.dumps(snapshot)
    with open("save_file.txt", "w") as out:
        out.write(encoded + "\n")
    
def load():
    """Replace the scene (``bs``/``ss``) with the one stored in ``save_file.txt``.

    The file holds the JSON written by export(): ``{"boxes": [...],
    "springs": [...]}``. Boxes keep their saved ids, so springs can be
    re-attached by id. Each spring's rest length ``d`` is restored from the
    file rather than recomputed.

    The file is fully parsed, and the new scene fully built, before the
    globals are replaced. A missing or corrupt file therefore raises (the
    exception propagates) and leaves the current scene intact. Springs that
    reference an id with no matching box are reported and skipped, rather
    than crashing on ``None``.
    """
    global bs, ss

    with open("save_file.txt", "r") as text_file:
        saved = json.load(text_file)

    boxes = [Box(bdat["x"], bdat["y"], bdat["m"], bdat["id"]) for bdat in saved["boxes"]]

    # The first box wins on duplicate ids, matching a front-to-back linear search.
    by_id = {}
    for box in boxes:
        by_id.setdefault(box.id, box)

    springs = []
    for sdat in saved["springs"]:
        a = by_id.get(sdat["a"])
        b = by_id.get(sdat["b"])
        if a is None or b is None:
            for end_id, end in ((sdat["a"], a), (sdat["b"], b)):
                if end is None:
                    print("box not found", end_id)
            continue
        spr = Spring(a, b)
        spr.d = sdat["d"]  # restore the saved rest length
        springs.append(spr)

    bs = boxes
    ss = springs

@window.event
def on_key_press(symbol, modifiers):
    """On S press, remember the box nearest the mouse as the new spring's first end."""
    global selected
    if symbol != key.S:
        return
    selected = closest_box(mousex, mousey)
        
@window.event
def on_key_release(symbol, modifiers):
    """Dispatch editor and training hotkeys when a key is released.

    C  clear boxes, springs, brains and the best score
    G  toggle gravity
    T  toggle training; toggling it off hands control to ``best_brain``
    B  add a box at the mouse position
    E  export the scene to the save file
    L  load the scene from the save file
    S  add a spring from ``selected`` (chosen on S press) to the box nearest the mouse
    """
    global bs, ss, training, gravity, brain, best_brain, best

    if symbol == key.C:
        bs = []
        ss = []
        brain = None
        # None (not False) matches the module default. False could later be
        # copied into `brain` by the T branch, slip past brain_update's
        # `brain is None` check, and crash on `.process`.
        best_brain = None
        best = 0
    elif symbol == key.G:
        gravity = not gravity
    elif symbol == key.T:
        training = not training
        if not training:
            brain = best_brain
    elif symbol == key.B:
        bs.append(Box(mousex, mousey))
    elif symbol == key.E:
        export()
    elif symbol == key.L:
        load()
    elif symbol == key.S:
        b = closest_box(mousex, mousey)
        # Only connect two distinct, existing boxes.
        if selected and b and b is not selected:
            ss.append(Spring(selected, b))

@window.event
def on_mouse_press(x, y, button, modifiers):
    """Begin a drag: grab the box nearest the click and flag the button as held."""
    global mouseDown, selected
    selected = closest_box(x, y)
    mouseDown = True

@window.event
def on_mouse_release(x, y, button, modifiers):
    """End a drag; on_draw stops pinning the selected box to the cursor."""
    global mouseDown
    mouseDown = False
    
@window.event
def on_mouse_motion(x, y, dx, dy):
    """Track the cursor position (no button held)."""
    global mousex, mousey
    mousex, mousey = x, y
    
@window.event
def on_mouse_drag(x, y, dx, dy, buttons, modifiers):
    """Track the cursor position while a button is held."""
    global mousex, mousey
    mousex, mousey = x, y

def brain_update():
    """Let the active brain set each spring's rest-length offset.

    The brain is fed each spring's normalised strain (``diff / d``). Output
    ``i`` becomes spring ``i``'s offset, scaled to 10% of its rest length.

    This does nothing when no brain is active, or when the brain's input or
    output size no longer matches the number of springs (for example, springs
    were added after training). The original crashed inside the matrix
    product in that case.
    """
    # `not brain` also rejects a stray False left by older code paths.
    if not brain:
        return
    n = len(ss)
    if brain.in_dim != n or brain.out_dim != n:
        return

    # A zero rest length would give 0/0 = NaN and poison the physics; feed 0 instead.
    strains = np.array([s.diff / s.d if s.d else 0.0 for s in ss])
    offsets = brain.process(strains)
    for spring, out in zip(ss, offsets):
        spring.offset = out * spring.d * .1
        
def _physics_step():
    """Advance the clock by dt and run one brain -> springs -> boxes update pass."""
    global time
    time += dt
    brain_update()
    for s in ss:
        s.update()
    for b in bs:
        b.update()


def _fits_layout(candidate):
    """Return True if ``candidate`` is a brain sized for the current spring count."""
    n = len(ss)
    return bool(candidate) and candidate.in_dim == n and candidate.out_dim == n


def _mutated_brain():
    """Return a new Brain for the current springs.

    It is a noisy copy of ``best_brain`` when that brain fits the layout;
    otherwise it is freshly random.
    """
    child = Brain(len(ss), len(ss))
    # Seeding from a brain of a different size would crash later in process().
    if _fits_layout(best_brain):
        child.W1 = best_brain.W1 + np.random.randn(*best_brain.W1.shape) * .5
        child.W2 = best_brain.W2 + np.random.randn(*best_brain.W2.shape) * .5
    return child


def update(evt):
    """pyglet clock callback: run a training batch if requested, else one live tick.

    A training batch runs ``trials`` episodes of ``steps`` ticks. Each episode
    reloads the saved layout, tries a brain mutated from the best so far, and
    keeps it when the centroid moves further right than ``best``. Afterwards
    the saved layout is reloaded and the best brain takes control. (Training
    switches itself off here, so the T-key "hand over best_brain" path never
    ran, and the last random mutation was left in charge.)

    Args:
        evt: elapsed time passed in by pyglet. It is unused; the sim steps by
            the fixed ``dt``.
    """
    global training, best, best_brain, brain

    if not training:
        _physics_step()
        return

    trials, steps = 10, 500
    for _ in range(trials):
        load()
        brain = _mutated_brain()
        start_x = avg_p()["x"]
        for _ in range(steps):
            _physics_step()
        gain = avg_p()["x"] - start_x

        if gain > best:
            best = gain
            print("new best", best)
            best_brain = Brain(len(ss), len(ss))
            best_brain.W1 = brain.W1.copy()
            best_brain.W2 = brain.W2.copy()

    training = False
    print("done training")
    # Show the winner from the saved starting layout.
    load()
    brain = best_brain if _fits_layout(best_brain) else None
            
@window.event
def on_close():
    """Log that the window is being closed."""
    print("closed")
    
@window.event
def on_draw():
    """Clear the window, pin the dragged box to the cursor, then draw boxes and springs."""
    window.clear()

    # While the button is held, the grabbed box follows the cursor and its
    # velocity is zeroed. `selected` is None when the press landed on an
    # empty scene and B then added a box mid-drag. The old `len(bs) > 0`
    # test let that through and crashed with AttributeError on None.
    if mouseDown and bs and selected is not None:
        selected.x = mousex
        selected.y = mousey
        selected.vx = 0
        selected.vy = 0

    for b in bs:
        b.render()
    for s in ss:
        s.render()
    
# Call update() 60 times per second, then hand control to pyglet's event loop.
tick_interval = 1 / 60.0
pyglet.clock.schedule_interval(update, tick_interval)
pyglet.app.run()

new best 0.2783365335826602
new best 4.575129968208046
new best 204.3474472455195
done training
done training
new best 221.14324280979832
done training
done training
done training
done training
done training
done training
done training
done training
done training
done training
closed
