In [None]:
from bootstrapSetup import *
from IPython.display import display, clear_output

dataset generated as 'loader
embedder created as transformer
Embedded shape: torch.Size([4, 128])
Using device: cuda


In [None]:
# Label for this experiment run — still a placeholder. TODO: give it a descriptive name.
EXPERIMENT_NAME= "ToBeNamed"

# Improvements
* Check connectivity by traversing the connection map
* Check edge cases in the wire sets

In [101]:
def checkConnectivity(connectionMap, socketMap, inputSockets=None):
    """Sweep which wires and sockets are reachable from the input sockets.

    Starting from every input socket, the wires it touches are queued. Each queued
    wire marks all of its non-source sockets as reached. Once both the ``base`` and
    ``collector`` sockets of a gate are reached, that gate's emitter is treated as
    a new source and the wires it touches are queued in turn.

    Args:
        connectionMap: mapping wire ID -> iterable of sockets on that wire. Each
            socket exposes ``.isSource`` and ``.name`` as ``(prefix, idnum, comptype)``.
        socketMap: mapping socket -> collection of wire IDs the socket touches. It
            must also accept lookups by socket-name string (e.g. ``"NAND0emitter"``),
            because emitters are looked up by a formatted name.
        inputSockets: optional iterable of ``(socket, position)`` pairs to start from.
            Defaults to the notebook-global ``inpSockets`` for backward compatibility.

    Returns:
        ``(connections, exhaustedSockets)``: the reached wire IDs in discovery order,
        and the set of reached sockets. Sockets fed through the source path (inputs,
        emitters) are recorded only if they touch at least one wire; sockets whose
        ``isSource`` is false are recorded whenever a reached wire lists them.
    """
    if inputSockets is None:
        inputSockets = inpSockets  # hidden dependency on the notebook global

    connections = []          # FIFO of reached wire IDs, in discovery order
    seenConnections = set()   # O(1) membership mirror of `connections`
    exhaustedSockets = set()

    def exhaustSocketSource(socket):
        """Record a driving socket as reached and enqueue every wire it touches."""
        wires = socketMap[socket]
        if wires:  # only record sources that are actually connected to something
            exhaustedSockets.add(socket)
        for c in wires:
            if c not in seenConnections:
                seenConnections.add(c)
                connections.append(c)

    for inp, _pos in inputSockets:
        exhaustSocketSource(inp)

    curIndex = 0
    # `connections` grows while it is being walked, so iterate by index.
    while curIndex < len(connections):
        c = connections[curIndex]
        for s in connectionMap[c]:
            if not s.isSource:
                exhaustedSockets.add(s)

            prefix, idnum, comptype = s.name
            if comptype:
                gate = f"{prefix}{idnum}"
                emitter = f"{gate}emitter"
                # `connections` holds wire IDs, not sockets, so test the socket set to
                # avoid re-exhausting an emitter that has already been handled.
                if (f"{gate}base" in exhaustedSockets
                        and f"{gate}collector" in exhaustedSockets
                        and emitter not in exhaustedSockets):
                    exhaustSocketSource(emitter)

        curIndex += 1
    return connections, exhaustedSockets

def prePruningv2(socketMap, exhaustedSockets):
    """Split sockets into reached inputs/outputs and collect gates with an unreached socket.

    Args:
        socketMap: mapping whose keys are socket objects exposing ``.name`` as a
            ``(prefix, idnum, comptype)`` tuple; ``comptype`` is falsy for plain
            input/output sockets.
        exhaustedSockets: collection of sockets reached by ``checkConnectivity``.

    Returns:
        ``(connectedInps, connectedOuts, notConnectedGateIDs)``: reached sockets with
        prefix ``"inp"``, reached non-gate sockets with any other prefix, and the
        ``idnum`` of every gate that has at least one socket outside
        ``exhaustedSockets`` (first-seen order, no duplicates).
    """
    reachedInputs, reachedOutputs = [], []
    # A dict keeps insertion order, so it serves as an ordered set of gate IDs.
    unreachedGates = {}

    for sock in socketMap:
        prefix, idnum, comptype = sock.name
        isReached = sock in exhaustedSockets

        if not comptype:
            if isReached:
                (reachedInputs if prefix == "inp" else reachedOutputs).append(sock)
            continue

        if not isReached:
            unreachedGates.setdefault(idnum, None)

    return reachedInputs, reachedOutputs, list(unreachedGates)

def gateRemoval(circuit, gatesToClean):
    """Replace selected gate cells of ``circuit`` with plain wire (value 1), in place.

    Gate cells are the entries with value > 1. They are numbered 0, 1, 2, ... in
    row-major scan order, and a gate is replaced when its number is requested.

    Args:
        circuit: 2-D array with ``.shape``; mutated in place.
        gatesToClean: iterable of gate numbers as ints or numeric strings (e.g. the
            ``idnum`` strings produced by ``prePruningv2``). Order and duplicates
            do not matter.

    Returns:
        The same ``circuit`` object.
    """
    # A set makes the result independent of the order (and duplicates) of
    # `gatesToClean`; matching only the "next" requested ID skips gates when the
    # list is not sorted ascending.
    pending = {int(g) for g in gatesToClean}
    if not pending:
        return circuit

    curGate = 0
    for i in range(circuit.shape[0]):
        for j in range(circuit.shape[1]):
            if circuit[i][j] > 1:
                if curGate in pending:
                    circuit[i][j] = 1
                    pending.discard(curGate)
                    if not pending:
                        return circuit  # every requested gate handled
                curGate += 1

    return circuit


In [102]:
# Training Function
def trainCatModel(catModel, dataloader, embeddingModel, losses, stopAccuracy = 0.05, batchLimit = 100, PLOTUPDATES = True):
    """Train ``catModel`` with cross-entropy on noised inputs.

    Each step samples random timesteps, noises ``batch`` with the global
    ``scheduler``, embeds ``labels`` with ``embeddingModel``, and minimises
    cross-entropy between the model output and ``batch.argmax(axis=1)``. The
    dataloader is cycled until a stop condition is met.

    Args:
        catModel: model called as ``catModel(noisy, timesteps, embeddings)``.
        dataloader: iterable of ``(batch, labels)`` pairs.
        embeddingModel: callable mapping ``labels`` to conditioning embeddings.
        losses: list that per-batch loss values are appended to (mutated in place).
        stopAccuracy: despite the name, a threshold on the batch *loss*; training
            stops as soon as a batch loss falls below it.
        batchLimit: maximum number of optimisation steps.
        PLOTUPDATES: if True, redraw a live loss curve every 10 recorded losses.

    Returns:
        ``(catModel, losses)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches at all (this used to loop forever).

    Relies on the notebook globals ``scheduler``, ``device``, ``torch``, ``nn``,
    ``plt``, ``clear_output`` and ``display``.
    """
    catModel.train()
    updateRate = 10  # For plotting purposes: redraw every `updateRate` recorded losses

    loss_fn = nn.CrossEntropyLoss()

    fig, ax = plt.subplots(figsize=(8, 5)) if PLOTUPDATES else (None, None)

    opt = torch.optim.Adam(catModel.parameters(), lr=1e-4)

    batchNum = 0
    isTraining = True
    while isTraining:
        sawBatch = False
        for batch, labels in dataloader:
            sawBatch = True

            # NOTE(review): torch.randint excludes its upper bound, so timestep
            # scheduler.TrainSteps - 1 is never sampled — confirm this is intended.
            timesteps = torch.randint(0, scheduler.TrainSteps - 1, (batch.shape[0],), device='cpu').long()
            noisyImgs = scheduler.addNoise(batch, timesteps)
            embeddings = embeddingModel(labels)

            pred = catModel(noisyImgs.to(device).float(), timesteps.to(device), embeddings.to(device))
            # Targets are the argmax over axis 1 of the clean batch.
            loss = loss_fn(pred, batch.to(device).argmax(axis=1))

            opt.zero_grad()
            loss.backward()
            opt.step()
            lossValue = loss.item()
            losses.append(lossValue)

            if PLOTUPDATES and len(losses) % updateRate == 0:
                clear_output(wait=True)
                ax.clear()
                ax.plot(losses)
                ax.set_xlabel('Batch')
                display(fig)

            batchNum += 1
            if lossValue < stopAccuracy or batchNum >= batchLimit:
                isTraining = False
                break

        if not sawBatch:
            if batchNum == 0:
                if PLOTUPDATES:
                    plt.close(fig)
                raise ValueError("dataloader yielded no batches; nothing to train on")
            # A one-shot iterable (e.g. a generator) is exhausted: stop instead of spinning forever.
            break

    if PLOTUPDATES:
        clear_output(wait=True)
        ax.clear()
        ax.plot(losses)
        ax.set_xlabel('Batch')
        display(fig)
        # Already shown via display(); closing it keeps the inline backend from
        # rendering the same figure a second time when the cell finishes.
        plt.close(fig)
    else:
        plt.plot(losses)
        plt.xlabel('Batch')

    return catModel, losses

In [None]:
def generateSocketCombinations(inputSockets, relevantIndexes):
    """Return every subset (the power set) of the selected input sockets.

    Args:
        inputSockets: all available input sockets; each item's first element is the
            socket itself (e.g. the ``(socket, position)`` pairs of ``inpSockets``).
        relevantIndexes: indexes into ``inputSockets`` of the sockets that are
            connected to components.

    Returns:
        A list of lists ordered by subset size (empty subset first) and, within one
        size, in ``itertools.combinations`` order. For ``k`` selected sockets the
        result has ``2**k`` entries.
    """
    # Imported locally so the function does not depend on `combinations`
    # leaking in through `from bootstrapSetup import *`.
    from itertools import combinations

    usedSockets = [inputSockets[i][0] for i in relevantIndexes]
    return [
        list(combo)
        for r in range(len(usedSockets) + 1)
        for combo in combinations(usedSockets, r)
    ]


__________

In [103]:
# Hand-built 13x13 test circuit. Judging from the socketMap output in this notebook,
# cell values appear to mean 0 = empty, 1 = wire, 2 = AND gate, 3 = NAND gate, with
# gates numbered 0, 1, 2, ... in row-major scan order — confirm against bootstrapSetup.
tc = np.array(
    [
        [1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 1, 1, 3, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [1, 2, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 0],
        [0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 3, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ])

In [104]:
            #    socketMap, wireSets = GetSocketMap(circuits[i], inpSockets + outSockets)
            #     connectionMap = GetConnectionMap(socketMap)


            #     inps, outs = prePruning(connectionMap, socketMap)
            #     circuit, connectionMap, circuitDirty = afterPruning(circuits[i], wireSets, connectionMap)
            #     circuit = pruneExcessWires(circuit
# GetSocketMap (presumably from bootstrapSetup) returns socketMap (socket -> set of
# wire IDs it touches) and wireSets, a grid shaped like tc in which each connected
# run of wire appears to carry its own integer label — per the printed output.
socketMap, wireSets = GetSocketMap(tc, inpSockets + outSockets)
print(wireSets)

[[1 1 0 0 0 0 2 2 2 2 2 2 2]
 [0 1 1 3 2 2 2 0 0 0 0 0 0]
 [0 1 0 3 0 0 0 0 0 0 0 0 0]
 [0 1 0 3 0 0 0 0 0 0 0 0 0]
 [1 1 0 3 0 0 0 0 4 0 0 0 0]
 [1 1 3 3 0 0 0 0 4 0 5 0 0]
 [0 1 0 0 0 0 0 0 4 0 5 0 0]
 [0 1 0 0 0 0 0 0 4 0 5 0 0]
 [0 1 0 6 6 6 0 0 0 0 5 0 0]
 [0 1 0 0 0 6 0 0 0 0 0 0 0]
 [0 0 0 0 0 6 0 0 0 0 0 0 0]
 [0 0 0 0 0 6 6 6 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0]]


In [105]:
# Judging by the displayed outputs, this inverts socketMap into wire ID -> list of sockets on that wire.
connectionMap = GetConnectionMap(socketMap)

In [106]:
# Socket -> set of wire IDs that socket touches.
socketMap

{inp0: {1},
 inp1: {1},
 inp2: set(),
 inp3: set(),
 out0: {2},
 out1: set(),
 out2: set(),
 out3: set(),
 NAND0base: {1},
 NAND0collector: {3},
 NAND0emitter: {2},
 AND1base: {1},
 AND1collector: {1},
 AND1emitter: {3},
 AND2base: set(),
 AND2collector: {4},
 AND2emitter: set(),
 NAND3base: set(),
 NAND3collector: {5},
 NAND3emitter: set()}

In [107]:
# Wire ID -> sockets attached to that wire.
connectionMap

{1: [inp0, inp1, NAND0base, AND1base, AND1collector],
 2: [out0, NAND0emitter],
 3: [NAND0collector, AND1emitter],
 4: [AND2collector],
 5: [NAND3collector]}

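1. Go through each input socket and add its connections (wire IDs) to a list.
2. Keep track of all exhausted (reached) sockets.
3. Pop a connection from the list, mark the sockets on it as exhausted, and go back to step 1 for any newly reachable sources.
4. If a gate's base and collector are both exhausted, add its emitter as a new source.

Open question: should the pending connections be a queue plus a set, or a dictionary?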

In [108]:
def checkConnectivity(connectionMap, socketMap, inputSockets=None):
    """Sweep which wires and sockets are reachable from the input sockets.

    Starting from every input socket, the wires it touches are queued. Each queued
    wire marks all of its non-source sockets as reached. Once both the ``base`` and
    ``collector`` sockets of a gate are reached, that gate's emitter is treated as
    a new source and the wires it touches are queued in turn.

    Args:
        connectionMap: mapping wire ID -> iterable of sockets on that wire. Each
            socket exposes ``.isSource`` and ``.name`` as ``(prefix, idnum, comptype)``.
        socketMap: mapping socket -> collection of wire IDs the socket touches. It
            must also accept lookups by socket-name string (e.g. ``"NAND0emitter"``),
            because emitters are looked up by a formatted name.
        inputSockets: optional iterable of ``(socket, position)`` pairs to start from.
            Defaults to the notebook-global ``inpSockets`` for backward compatibility.

    Returns:
        ``(connections, exhaustedSockets)``: the reached wire IDs in discovery order,
        and the set of reached sockets. Sockets fed through the source path (inputs,
        emitters) are recorded only if they touch at least one wire; sockets whose
        ``isSource`` is false are recorded whenever a reached wire lists them.
    """
    if inputSockets is None:
        inputSockets = inpSockets  # hidden dependency on the notebook global

    connections = []          # FIFO of reached wire IDs, in discovery order
    seenConnections = set()   # O(1) membership mirror of `connections`
    exhaustedSockets = set()

    def exhaustSocketSource(socket):
        """Record a driving socket as reached and enqueue every wire it touches."""
        wires = socketMap[socket]
        if wires:  # only record sources that are actually connected to something
            exhaustedSockets.add(socket)
        for c in wires:
            if c not in seenConnections:
                seenConnections.add(c)
                connections.append(c)

    for inp, _pos in inputSockets:
        exhaustSocketSource(inp)

    curIndex = 0
    # `connections` grows while it is being walked, so iterate by index.
    while curIndex < len(connections):
        c = connections[curIndex]
        for s in connectionMap[c]:
            if not s.isSource:
                exhaustedSockets.add(s)

            prefix, idnum, comptype = s.name
            if comptype:
                gate = f"{prefix}{idnum}"
                emitter = f"{gate}emitter"
                # `connections` holds wire IDs, not sockets, so test the socket set to
                # avoid re-exhausting an emitter that has already been handled.
                if (f"{gate}base" in exhaustedSockets
                        and f"{gate}collector" in exhaustedSockets
                        and emitter not in exhaustedSockets):
                    exhaustSocketSource(emitter)

        curIndex += 1
    return connections, exhaustedSockets


# Reachability sweep from the input sockets: wire IDs reached, plus sockets reached.
connectedSets, exhaustedSockets = checkConnectivity(
    connectionMap=connectionMap,
    socketMap=socketMap,
)

In [109]:
# One line per socket: was it reached by the connectivity sweep?
for sock in socketMap:
    wasReached = sock in exhaustedSockets
    print(wasReached, sock)

True inp0
True inp1
False inp2
False inp3
True out0
False out1
False out2
False out3
True NAND0base
True NAND0collector
True NAND0emitter
True AND1base
True AND1collector
True AND1emitter
False AND2base
False AND2collector
False AND2emitter
False NAND3base
False NAND3collector
False NAND3emitter


In [110]:
def prePruningv2(socketMap, exhaustedSockets):
    """Split sockets into reached inputs/outputs and collect gates with an unreached socket.

    Args:
        socketMap: mapping whose keys are socket objects exposing ``.name`` as a
            ``(prefix, idnum, comptype)`` tuple; ``comptype`` is falsy for plain
            input/output sockets.
        exhaustedSockets: collection of sockets reached by ``checkConnectivity``.

    Returns:
        ``(connectedInps, connectedOuts, notConnectedGateIDs)``: reached sockets with
        prefix ``"inp"``, reached non-gate sockets with any other prefix, and the
        ``idnum`` of every gate that has at least one socket outside
        ``exhaustedSockets`` (first-seen order, no duplicates).
    """
    reachedInputs, reachedOutputs = [], []
    # A dict keeps insertion order, so it serves as an ordered set of gate IDs.
    unreachedGates = {}

    for sock in socketMap:
        prefix, idnum, comptype = sock.name
        isReached = sock in exhaustedSockets

        if not comptype:
            if isReached:
                (reachedInputs if prefix == "inp" else reachedOutputs).append(sock)
            continue

        if not isReached:
            unreachedGates.setdefault(idnum, None)

    return reachedInputs, reachedOutputs, list(unreachedGates)
        
# Classify reached I/O sockets and collect the IDs of gates with an unreached socket.
connectedInps, connectedOuts, notConnectedGateIDs = prePruningv2(
    socketMap=socketMap,
    exhaustedSockets=exhaustedSockets,
)

In [111]:
# Report every gate that prePruningv2 flagged.
for gateId in notConnectedGateIDs:
    print(f"Gate {gateId} is not connected to any input socket")

Gate 2 is not connected to any input socket
Gate 3 is not connected to any input socket


In [112]:
def gateRemoval(circuit, gatesToClean):
    """Replace selected gate cells of ``circuit`` with plain wire (value 1), in place.

    Gate cells are the entries with value > 1. They are numbered 0, 1, 2, ... in
    row-major scan order, and a gate is replaced when its number is requested.

    Args:
        circuit: 2-D array with ``.shape``; mutated in place.
        gatesToClean: iterable of gate numbers as ints or numeric strings (e.g. the
            ``idnum`` strings produced by ``prePruningv2``). Order and duplicates
            do not matter.

    Returns:
        The same ``circuit`` object.
    """
    # A set makes the result independent of the order (and duplicates) of
    # `gatesToClean`; matching only the "next" requested ID skips gates when the
    # list is not sorted ascending.
    pending = {int(g) for g in gatesToClean}
    if not pending:
        return circuit

    curGate = 0
    for i in range(circuit.shape[0]):
        for j in range(circuit.shape[1]):
            if circuit[i][j] > 1:
                if curGate in pending:
                    circuit[i][j] = 1
                    pending.discard(curGate)
                    if not pending:
                        return circuit  # every requested gate handled
                curGate += 1

    return circuit

# gateRemoval edits its argument in place; pass a copy so `tc` keeps the original
# circuit and this cell can be re-run safely.
tc2 = gateRemoval(tc.copy(), notConnectedGateIDs)

In [113]:
# Gate IDs (strings taken from the socket names) flagged for removal.
notConnectedGateIDs

['2', '3']

In [114]:
# Circuit after the flagged gates were turned into plain wire (value 1).
tc2

array([[1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
       [0, 1, 1, 3, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [1, 2, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
       [0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
       [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [116]:
def afterPruningv2(circuit, wiresets, exhaustedSockets, socketMap):
    """Erase wire cells that belong to no wire set touched by a reached socket.

    A cell is cleared (set to 0) when ``circuit`` holds a 1 there and its label in
    ``wiresets`` is not among the wire IDs that ``socketMap`` lists for any socket
    in ``exhaustedSockets``. Cells with any other value are left untouched.

    Args:
        circuit: 2-D array with ``.shape``; modified in place.
        wiresets: grid of wire-set labels indexed like ``circuit`` (as produced by
            ``GetSocketMap``).
        exhaustedSockets: sockets reached by ``checkConnectivity``.
        socketMap: mapping socket -> collection of wire IDs.

    Returns:
        ``(circuit, circuitWasUpdated)``: the same array, and whether any cell changed.
    """
    liveWires = {wire for sock in exhaustedSockets for wire in socketMap[sock]}

    changed = False
    nRows, nCols = circuit.shape[0], circuit.shape[1]
    for row in range(nRows):
        for col in range(nCols):
            isWire = circuit[row][col] == 1
            if isWire and wiresets[row][col] not in liveWires:
                circuit[row][col] = 0
                changed = True

    return circuit, changed

# Clear wire cells that belong to no reached wire set. Work on a copy so `tc2`,
# displayed above, is not silently modified, and keep the result under new names.
tc3, tc3WasUpdated = afterPruningv2(tc2.copy(), wireSets, exhaustedSockets, socketMap)
tc3, tc3WasUpdated

(array([[1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 1, 1, 3, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 True)