In [2]:
import jax
import numpy as np
import matplotlib.pyplot as plt

# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

In [3]:
#Calculates some weight statistic from a weight counter
def calculateWeightStatistics(weightCounter: dict, doPrint = False):
    totalValues = 0
    for _, n in weightCounter.items():
        totalValues+=n
    maxValue = max(weightCounter)
    minValue = min(weightCounter)
    zeroPercentage = 100*weightCounter[0]/totalValues if 0 in weightCounter else 0
    numberOfUniqueValues = len(weightCounter)

    if doPrint:
        print("N: %d\t min/max: %.2f/%.2f\t nValues: %d\t percentageZero: %.2f" % 
          (totalValues, minValue, maxValue, numberOfUniqueValues, zeroPercentage))
    return {"totalValues":totalValues, "maxValue": maxValue, "minValue": minValue, "zeroPercentage": zeroPercentage, "numberOfUniqueValues": numberOfUniqueValues}

In [4]:
class Model:
    def __init__(self, raspFunction: rasp.SOp, inputs, seqLength: int, name: str):
        self.raspFunction = raspFunction
        self.inputs = inputs
        self.seqLength = seqLength
        self.model = compiling.compile_rasp_to_model(self.raspFunction, self.inputs, self.seqLength, compiler_bos="BOS")
        self.name = name

        self.weightStatistics = {}
        self.updateWeightStatistics()

    def updateWeightStatistics(self):
        self.weightStatistics = {}
        
        totalCounter = {}
        for name1, layer in self.model.params.items():
            self.weightStatistics[name1] = {}
            #print(name1, type(layer))
            for name2, weight in layer.items():
                weightCounter = {}
                #print("\t", name2, type(weight))

                #Find unique weights and count instances for the weights
                for t in weight.flatten():
                    t = float(t)
                    if t in weightCounter:
                        weightCounter[t]+=1
                    else:
                        weightCounter[t]=1

                #print("\t",end="  ")
                self.weightStatistics[name1][name2] = calculateWeightStatistics(weightCounter)

                #Appends the weight counts to the total counts
                for number, count in weightCounter.items():
                    if number in totalCounter:
                        totalCounter[number]+=count
                    else:
                        totalCounter[number]=count

        #print("\nTotal statistics")
        self.weightStatistics["total"] = calculateWeightStatistics(totalCounter)

    def printWeightStatistics(self, includeB=False):
        print(self.model.model_config)
        print("\nLayer analysis:")

        for name1, _ in self.weightStatistics.items():
            print(name1)
            if name1=="total":
                weightStats=self.weightStatistics[name1]
                print("\t  N: %d\t min/max: %.2f/%.2f\t nValues: %d\t percentageZero: %.2f" % 
                    (weightStats["totalValues"], weightStats["minValue"], weightStats["maxValue"], weightStats["numberOfUniqueValues"], weightStats["zeroPercentage"]))
                continue
            
            for name2, weightStats in self.weightStatistics[name1].items():
                if name2=="b" and includeB!=True:
                    continue
                print("\t", name2)
                print("\t  N: %d\t min/max: %.2f/%.2f\t nValues: %d\t percentageZero: %.2f" % 
                    (weightStats["totalValues"], weightStats["minValue"], weightStats["maxValue"], weightStats["numberOfUniqueValues"], weightStats["zeroPercentage"]))
    
    
    def evaluate(self, input):
        return self.model.apply(input).decoded
    



In [5]:
inputs = [0,1,2,3,4,5]
print(inputs[1:])

[1, 2, 3, 4, 5]


In [6]:
acceptedNamesAndInput = {"reverse": ["a","b","c","d","e"], #Tokens doesn't matter much. Only the quantity influnce the results due to encoding (I think)
                         "hist": ["a","b","c","d"], #Tokens doesn't matter much. Only the quantity influnce the results due to encoding (I think)
                         "sort": [1,2,3,4,5,6], #[0,1,2,3,4,5,6]    Seems to fail sometimes if 0 is included (irrespktive of if 0 is in the failed input or not, don't know why)
                         "most-freq": [1,2,3,4,5],
                         "shuffle_dyck1": ["(",")"],
                         "shuffle_dyck2": ["(",")","{","}"]}     #Could theoretically be adapted into shuffle dyck-k but would still require unique tokens for each k

def generateData(name: str, maxSeqLength: int, size: int):
    data = [None]*size

    match name:
        case "reverse":
            acceptedTokens = acceptedNamesAndInput[name]

            for i in range(size):
                inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length

                inputSeq = []
                outputSeq = []
                for t in np.random.choice(acceptedTokens, inputLength):
                    inputSeq.append(t)
                    outputSeq.insert(0,t)
                inputSeq.insert(0,"BOS")
                outputSeq.insert(0,"BOS")

                data[i] = (inputSeq, outputSeq)

        case "hist":
            acceptedTokens = acceptedNamesAndInput[name]  
            
            for i in range(size):
                inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length

                inputSeq = []
                tokenCounter = dict(zip(acceptedTokens, [0]*len(acceptedTokens)))   #Counter built during generating input
                for t in np.random.choice(acceptedTokens, inputLength):
                    inputSeq.append(t)
                    tokenCounter[t]+=1
    
                outputSeq = []
                for t in inputSeq:  #Fill output according to token counter
                    outputSeq.append(tokenCounter[t])

                inputSeq.insert(0,"BOS")
                outputSeq.insert(0,"BOS")

                data[i] = (inputSeq, outputSeq)

        case "sort":
            acceptedTokens = acceptedNamesAndInput[name]  
            
            for i in range(size):
                inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length

                inputSeq = []
                outputSeq = []
                for t in np.random.choice(acceptedTokens, inputLength):
                    inputSeq.append(t)
                    outputSeq.append(t)
    
                inputSeq.insert(0,"BOS")
                outputSeq.sort()
                outputSeq.insert(0,"BOS")

                data[i] = (inputSeq, outputSeq)

        case "most-freq":   #sort based on most frequent token with original position as tie breaker
            acceptedTokens = acceptedNamesAndInput[name]  

            for i in range(size):
                inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length

                inputSeq = []
                tempSeq = []
                tokenCounter = dict(zip(acceptedTokens, [0]*len(acceptedTokens)))   #Counter built during generating input
                for t in np.random.choice(acceptedTokens, inputLength):
                    inputSeq.append(t)
                    tokenCounter[t]+=1
                    tempSeq.append(t)    
                
                tempSeq.sort(key = (lambda x: -tokenCounter[x]))  #Sort the list in descending order of frequency

                outputSeq = tempSeq

                #Groups the tokens (Apparently not done by the Tracr solution)
                """
                outputSeq = []
                for t in tempSeq:
                    if t not in outputSeq:
                        for ii in range(tokenCounter[t]):
                            outputSeq.append(t)
                """

                inputSeq.insert(0,"BOS")
                outputSeq.insert(0,"BOS")

                data[i] = (inputSeq, outputSeq)

        case "shuffle_dyck1":
            acceptedTokens = acceptedNamesAndInput[name]

            for i in range(size):
                for ii in range(3):     #Ensures that roughly one out of eight sequences has an odd length
                    inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length
                    if inputLength%2==0:
                        break

                inputSeq = []
                tokenCount = {"(":0,")":0}
                tokenProb = np.zeros(len(acceptedTokens))   #Live probabilty distribution to more evenly distribute the balanced and unblanaced sequences
                tokenProb[1] = 1/(inputLength+1)
                tokenProb[0] = 1 - tokenProb[1]

                #Build the sequence token by token and ensuring the probability of drawing a balanced sequence is always higher than drawing an unbalanced sequence
                for ind in range(inputLength):
                    t = np.random.choice(acceptedTokens, 1, p=tokenProb)[0]
                    tokenCount[t]+=1
                    inputSeq.append(t)

                    tokenDiff = tokenCount["("]-tokenCount[")"]
                    if tokenDiff == 0:  #High probability of begining paranthesis if balanced
                        tokenProb[1] = 1/(inputLength+1)
                        tokenProb[0] = 1 - tokenProb[1]
                    elif tokenDiff > 0:   #High probability of end paranthesis if more begining paranthesis
                        tokenProb[0] = 1/((inputLength+1)*tokenDiff)
                        tokenProb[1] = 1 - tokenProb[0]
                    else: #High probability of begining paranthesis if more end paranthesis
                        tokenProb[1] = 1/((inputLength+1)*(-tokenDiff))
                        tokenProb[0] = 1 - tokenProb[1]
                
                #Checks for balance
                balanceCounter=0
                for t in inputSeq:
                    if t=="(":
                        balanceCounter+=1
                    else:
                        balanceCounter-=1
                    if balanceCounter<0:
                        break
                
                if balanceCounter!=0:
                    outputSeq = [0]*len(inputSeq)
                else:
                    outputSeq = [1]*len(inputSeq)

                inputSeq.insert(0,"BOS")
                outputSeq.insert(0,"BOS")

                data[i] = (inputSeq, outputSeq)

        case "shuffle_dyck2":
            acceptedTokens = acceptedNamesAndInput[name]

            for i in range(size):
                for ii in range(3):     #Ensures that roughly one out of eight sequences has an odd length
                    inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length
                    if inputLength%2==0:
                        break

                inputSeq = []
                tokenCount = {"(":0,")":0,"{":0,"}":0}
                tokenProb = np.zeros(len(acceptedTokens))   #Live probabilty distribution to more evenly distribute the balanced and unblanaced sequences
                tokenProb[1] = 1/((inputLength+1)*2)
                tokenProb[0] = 1/2 - tokenProb[1]
                tokenProb[3] = tokenProb[1]
                tokenProb[2] = tokenProb[0]

                #Build the sequence token by token and ensuring the probability of drawing a balanced sequence is always higher than drawing an unbalanced sequence
                for ind in range(inputLength):
                    t = np.random.choice(acceptedTokens, 1, p=tokenProb)[0]
                    tokenCount[t]+=1
                    inputSeq.append(t)

                    tokenDiff1 = tokenCount["("]-tokenCount[")"]
                    tokenDiff2 = tokenCount["{"]-tokenCount["}"]
                    if tokenDiff1 == 0 and tokenDiff2==0:  #High probability of begining paranthesis if balanced
                        tokenProb[1] = 1/((inputLength+1)*2)
                        tokenProb[0] = 1/2 - tokenProb[1]
                        tokenProb[3] = tokenProb[1]
                        tokenProb[2] = tokenProb[0]
                    #High probability of end paranthesis if more begining paranthesis
                    elif tokenDiff2 > 0 and tokenDiff1 > 0:
                        tokenProb[0] = 1/((inputLength+1)*tokenDiff1*2)
                        tokenProb[2] = 1/((inputLength+1)*tokenDiff2*2)
                        tokenProb[1] = 1/2 - tokenProb[0]
                        tokenProb[3] = 1/2 - tokenProb[2]
                    elif tokenDiff1 > 0 and tokenDiff2==0: 
                        tokenProb[1] = 1 - 1/((inputLength+1)*tokenDiff1)
                        split = 1 - tokenProb[1]    #The reminder of probability to distribute
                        tokenProb[2] = split - split/((inputLength+1))    #More likely to start a new parenthesis than break sequence
                        split = split - tokenProb[2]
                        tokenProb[0] = split/2
                        tokenProb[3] = split/2
                    elif tokenDiff2 > 0 and tokenDiff1==0:   
                        tokenProb[3] = 1 - 1/((inputLength+1)*tokenDiff2)
                        split = 1 - tokenProb[3]    #The reminder of probability to distribute
                        tokenProb[0] = split - split/((inputLength+1))    #More likely to start a new parenthesis than break sequence
                        split = split - tokenProb[0]
                        tokenProb[1] = split/2
                        tokenProb[2] = split/2
                    #High probability of begining paranthesis if more end paranthesis
                    elif tokenDiff2 < 0 and tokenDiff1 < 0:
                        tokenProb[1] = 1/((inputLength+1)*(-tokenDiff1)*2)
                        tokenProb[3] = 1/((inputLength+1)*(-tokenDiff2)*2)
                        tokenProb[0] = 1/2 - tokenProb[1]
                        tokenProb[2] = 1/2 - tokenProb[3]
                    elif tokenDiff1 < 0 and tokenDiff2 == 0:
                        tokenProb[0] = 1 - 1/((inputLength+1)*(-tokenDiff1))
                        split = 1 - tokenProb[0]    #The reminder of probability to distribute
                        tokenProb[2] = split - split/((inputLength+1))    #More likely to start a new parenthesis than break sequence
                        split = split - tokenProb[2]
                        tokenProb[1] = split/2
                        tokenProb[3] = split/2
                    elif tokenDiff2 < 0 and tokenDiff1 == 0:
                        tokenProb[2] = 1 - 1/((inputLength+1)*(-tokenDiff2))
                        split = 1 - tokenProb[2]    #The reminder of probability to distribute
                        tokenProb[0] = split - split/((inputLength+1))    #More likely to start a new parenthesis than break sequence
                        split = split - tokenProb[0]
                        tokenProb[1] = split/2
                        tokenProb[3] = split/2
                    #Higher probability to balance the sequence if currently unbalanced
                    elif tokenDiff1 > 0 and tokenDiff2 < 0:
                        tokenProb[1] = 1/((inputLength+1)*tokenDiff1*2)
                        tokenProb[2] = 1/((inputLength+1)*(-tokenDiff2)*2)
                        tokenProb[0] = 1/2 - tokenProb[1]
                        tokenProb[3] = 1/2 - tokenProb[2]
                    elif tokenDiff2 > 0 and tokenDiff1 < 0:
                        tokenProb[3] = 1/((inputLength+1)*tokenDiff2*2)
                        tokenProb[0] = 1/((inputLength+1)*(-tokenDiff1)*2)
                        tokenProb[1] = 1/2 - tokenProb[0]
                        tokenProb[2] = 1/2 - tokenProb[3]
                
                #Checks for balance
                balanceCounter=[0,0]
                for t in inputSeq:
                    if t=="(":
                        balanceCounter[0]+=1
                    if t==")":
                        balanceCounter[0]-=1
                    if t=="{":
                        balanceCounter[1]+=1
                    if t=="}":
                        balanceCounter[1]-=1
                    
                    if balanceCounter[0]<0 or balanceCounter[1]<0:
                        break
                
                if balanceCounter[0]!=0 or balanceCounter[1]!=0:
                    outputSeq = [0]*len(inputSeq)
                else:
                    outputSeq = [1]*len(inputSeq)

                inputSeq.insert(0,"BOS")
                outputSeq.insert(0,"BOS")

                data[i] = (inputSeq, outputSeq)


        case _:
            print(name, "is not an accepted name the accepted names are",acceptedNamesAndInput)
            return None

    return data

data = generateData("shuffle_dyck2", 5, 100)
print(data[:5])

def generateModel(name: str, maxLength: int) -> Model:
    model = None
    match name:
        case "reverse":
            inputs = {t for t in acceptedNamesAndInput[name]}
            model = Model(lib.make_reverse(rasp.tokens), inputs, maxLength, name)

        case "hist":
            inputs = {t for t in acceptedNamesAndInput[name]}
            model = Model(lib.make_hist(), inputs, maxLength, name)

        case "sort":
            inputs = {t for t in acceptedNamesAndInput[name]}
            model = Model(lib.make_sort(rasp.tokens, rasp.tokens, max_seq_len=maxLength, min_key=min(inputs)), inputs, maxLength, name)

        case "most-freq":
            inputs = {t for t in acceptedNamesAndInput[name]}
            model = Model(lib.make_sort_freq(maxLength), inputs, maxLength, name)

        case "shuffle_dyck1":
            inputs = {t for t in acceptedNamesAndInput[name]}
            model = Model(lib.make_shuffle_dyck(["()"]), inputs, maxLength, name)
        
        case "shuffle_dyck2":
            inputs = {t for t in acceptedNamesAndInput[name]}
            model = Model(lib.make_shuffle_dyck(["()","{}"]), inputs, maxLength, name)

        case _:
            print(name, "is not an accepted name the accepted names are",acceptedNamesAndInput)
            return None

    return model


[(['BOS', '{', '{', '}', '}'], ['BOS', 1, 1, 1, 1]), (['BOS', '{', '}', '{', ')'], ['BOS', 0, 0, 0, 0]), (['BOS', '(', ')', ')', '(', '('], ['BOS', 0, 0, 0, 0, 0]), (['BOS', '{', '}', '(', ')'], ['BOS', 1, 1, 1, 1]), (['BOS', '(', ')', '('], ['BOS', 0, 0, 0])]


In [7]:
#Prints some statistics on the generated dyck data
def checkDyckBalance(data):
    oddLength = 0
    balanced = 0

    for (input, output) in data:
        if len(input)%2==0 :    #length + bos
            oddLength +=1
        if output[1]==1:
            balanced+=1
    
    oddLength /= len(data)/100
    balanced /= len(data)/100

    print("Percentage of data which is:")
    print("Of odd length:", oddLength)
    print("Balanced:", balanced)

print("dyck1")
checkDyckBalance(generateData("shuffle_dyck1", 5, 10000))
checkDyckBalance(generateData("shuffle_dyck1", 10, 10000))
checkDyckBalance(generateData("shuffle_dyck1", 15, 10000))
checkDyckBalance(generateData("shuffle_dyck1", 50, 10000))

print("\ndyck2")
checkDyckBalance(generateData("shuffle_dyck2", 5, 10000))
checkDyckBalance(generateData("shuffle_dyck2", 10, 10000))
checkDyckBalance(generateData("shuffle_dyck2", 15, 10000))
checkDyckBalance(generateData("shuffle_dyck2", 50, 10000))

#Seems to work fairly well, roughly between 40 and 50% is balanced depending on the maximum size

dyck1
Percentage of data which is:
Of odd length: 12.13
Balanced: 43.09
Percentage of data which is:
Of odd length: 9.33
Balanced: 49.61
Percentage of data which is:
Of odd length: 12.62
Balanced: 47.66
Percentage of data which is:
Of odd length: 11.73
Balanced: 52.82

dyck2
Percentage of data which is:
Of odd length: 12.46
Balanced: 40.47
Percentage of data which is:
Of odd length: 8.95
Balanced: 47.3
Percentage of data which is:
Of odd length: 12.17
Balanced: 46.7
Percentage of data which is:
Of odd length: 11.57
Balanced: 50.21


In [8]:
#Returns the boolean result for each case in the data set
def evaluateModel(model: Model, data):
    print("Evaluating model:",model.name)
    N=len(data)
    booleanAccuracy = np.zeros(N)
    
    for i in range(N):
        inputSeq, trueOutputSeq = data[i]
        outputSeq = model.evaluate(inputSeq)

        seqLength = len(trueOutputSeq)
        sameToken = np.zeros(seqLength)
        for ii in range(seqLength):
            sameToken[ii] = (outputSeq[ii]==trueOutputSeq[ii])
        
        booleanAccuracy[i] = (np.sum(sameToken) == seqLength)

        #Add loading bar to keep track of progress

    return booleanAccuracy


In [9]:
name = "shuffle_dyck2"
maxSeqLen = 5
data = generateData(name, maxSeqLen, 1000)
model = generateModel(name, maxSeqLen)

print(data[:5])

booleanAccuracy = evaluateModel(model, data)
accuracy=np.mean(booleanAccuracy)
print("Accuracy:",accuracy)

[(['BOS', '(', ')', '(', ')'], ['BOS', 1, 1, 1, 1]), (['BOS', '{', '}', '{', '('], ['BOS', 0, 0, 0, 0]), (['BOS', '{', '}', '(', ')'], ['BOS', 1, 1, 1, 1]), (['BOS', '(', ')', '(', ')'], ['BOS', 1, 1, 1, 1]), (['BOS', '(', ')', '{'], ['BOS', 0, 0, 0])]
Evaluating model: shuffle_dyck2
Accuracy: 1.0


In [10]:
print(np.argwhere(booleanAccuracy-1))
print(data[7])

print(model.evaluate(data[7][0]))

[]
(['BOS', '{', '('], ['BOS', 0, 0])
['BOS', False, False]


In [11]:
cat = {0:1, 2:4, 1:6}
print(len(cat))

3


In [12]:


#Quick function to check for if all "b" weights are truly zero
def analyzeB(model: Model):
    for name1, layer in model.model.params.items():
        for name2, weight in layer.items():
            if name2!="b":
                continue
            weightCounter = {}

            #Find unique weights and count instances for the weights
            for t in weight.flatten():
                t = float(t)
                if t in weightCounter:
                    weightCounter[t]+=1
                else:
                    weightCounter[t]=1

            calculateWeightStatistics(weightCounter, True)

#name = "reverse"
name = "reverse"
maxSeqLen = 5
data = generateData(name, maxSeqLen, 1000)
model = generateModel(name, maxSeqLen)

#analyzeB(model)
model.updateWeightStatistics()
model.printWeightStatistics()

TransformerConfig(num_heads=1, num_layers=4, key_size=12, mlp_hidden_size=30, dropout_rate=0.0, activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x000002B8A6099820>, layer_norm=False, causal=False)

Layer analysis:
pos_embed
	 embeddings
	  N: 270	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 98.15
token_embed
	 embeddings
	  N: 315	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 95.56
transformer/layer_0/attn/key
	 w
	  N: 540	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 98.89
transformer/layer_0/attn/linear
	 w
	  N: 540	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 99.81
transformer/layer_0/attn/query
	 w
	  N: 540	 min/max: 0.00/100.00	 nValues: 2	 percentageZero: 95.19
transformer/layer_0/attn/value
	 w
	  N: 540	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 99.81
transformer/layer_0/mlp/linear_1
	 w
	  N: 1350	 min/max: -75.00/100.00	 nValues: 13	 percentageZero: 98.37
transformer/layer_0/mlp/linear_2
	 w
	  N: 1350	 min/max: -1.00/1.00	 nValues

In [13]:
#Add noise to the model weights according too noiseType, amount and param
def addNoise(model: Model, noiseType = "bitFlip", amount=1, param = 0.1, includeEncoding = False):
    noiseTypes = ["bitFlip", "gaussian", "flipFirst", "temp"]
    if noiseType not in noiseTypes:
        print("Error: noiseType needs to be one of", noiseTypes)
        return
    
    match noiseType:
        #Flip binary bits 
        #If amount is a integer it flips that many random bits, if it is float it flips that fraction of bits
        case "bitFlip":
            #find binary weights in the model
            #Ensure that the weights are correctly changed before commiting to design. If assignment doesn't work I'll save the keys to access

            #Saves the keys to access all the layers with binary weights as well as the weight statistics for that layer
            binaryWeights = [] 
            totalCount = 0
            for name1, _ in model.weightStatistics.items():
                if name1=="total":
                    continue
                for name2, weightStats in model.weightStatistics[name1].items():
                    if weightStats["numberOfUniqueValues"]==2:
                        if includeEncoding == False and name2=="embeddings":
                            continue
                        binaryWeights.append((name1, name2, weightStats))
                        totalCount+=weightStats["totalValues"]

            if type(amount)==int:
                #Randomly selects "amount" bits to flip
                #The probability is equal for all applicable parameters
                for i in range(amount):
                    #Parameter used to figure out the index where flip happens
                    index = np.random.randint(totalCount)
                    for name1, name2, stats in binaryWeights:
                        layerShape = model.model.params[name1][name2].shape
                        layerCount = layerShape[0]*layerShape[1]

                        #Check if this layer is the layer where the flip happens
                        if index>=layerCount:   #Not the correct layer
                            index-=layerCount
                            continue
                        
                        #Flip happens on this layer
                        index = (index//layerShape[1], index%layerShape[1])
                        model.model.params[name1][name2] = model.model.params[name1][name2].at[index[0],index[1]].set(
                            stats["maxValue"] - float(model.model.params[name1][name2][index[0],index[1]])
                        )

                        #print("Flip at:",name1,name2,index)

                        break

            elif type(amount)==float:
                print("Percentage bitflip not yet implemented")
                return
            else:
                print("Error: amount needs to be int or float")

        case "gaussian":

            if type(amount)==int:
                print("Counted gaussian not yet implemented")
                return
            
            #Adds gaussian noise with standard deviation "param" to "amount" fraction of the weights
            elif type(amount)==float:
                for name1, _ in model.weightStatistics.items():
                    if name1=="total":
                        continue
                    for name2, weightStats in model.weightStatistics[name1].items():
                        if name2=="b":
                            continue
                        if includeEncoding == False and name2=="embeddings":
                            continue

                        layerShape = model.model.params[name1][name2].shape
                        model.model.params[name1][name2] = model.model.params[name1][name2] + \
                            np.where(np.random.rand(layerShape[0], layerShape[1])<amount, np.random.normal(0, param, layerShape), 0)
                        
                return
            else:
                print("Error: amount needs to be int or float")

        case "flipFirst":
            model.model.params["transformer/layer_0/attn/key"]["w"] = model.model.params["transformer/layer_0/attn/key"]["w"].at[0,0].set(1)
            #weights = model.model.params["transformer/layer_0/attn/key"]["w"]   #This assignment does not work. Guessing it copies due to strict immutability
            #weights = weights.at[0,0].set(1)
            print(model.model.params["transformer/layer_0/attn/key"]["w"][0,0])
            print(model.model.params["transformer/layer_0/attn/key"]["w"])

        case "temp":
            shape = model.model.params["transformer/layer_0/attn/key"]["w"].shape
            print(shape)
            model.model.params["transformer/layer_0/attn/key"]["w"] = model.model.params["transformer/layer_0/attn/key"]["w"]+np.ones(shape) #Easy lol
            return


        case _:
            print("Error: noiseType not implemented")

    
    return

name = "reverse"
maxSeqLen = 5
model = generateModel(name, maxSeqLen)
addNoise(model, noiseType="gaussian", amount=1.0, param=0.001)
model.updateWeightStatistics()
model.printWeightStatistics()

TransformerConfig(num_heads=1, num_layers=4, key_size=12, mlp_hidden_size=30, dropout_rate=0.0, activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x000002B8A6099820>, layer_norm=False, causal=False)

Layer analysis:
pos_embed
	 embeddings
	  N: 270	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 98.15
token_embed
	 embeddings
	  N: 315	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 95.56
transformer/layer_0/attn/key
	 w
	  N: 540	 min/max: -0.00/1.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/attn/linear
	 w
	  N: 540	 min/max: -0.00/1.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/attn/query
	 w
	  N: 540	 min/max: -0.00/100.00	 nValues: 539	 percentageZero: 0.00
transformer/layer_0/attn/value
	 w
	  N: 540	 min/max: -0.00/1.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/mlp/linear_1
	 w
	  N: 1350	 min/max: -75.00/100.00	 nValues: 1350	 percentageZero: 0.00
transformer/layer_0/mlp/linear_2
	 w
	  N: 1350	 min/max: -1.00/1.00

In [14]:
booleanAccuracy = evaluateModel(model, data)
accuracy=np.mean(booleanAccuracy)
print("Accuracy:",accuracy)

Evaluating model: reverse
Accuracy: 0.831


## Training

In [None]:
#Trying to figure out what kind of haiku model the tracr models are (how do they relate to pure haiku models) and how I can train these models

## Notes

#### Testing the base functions and generating the test data

The sort function does not have a 100% accuracy. This only seems to apply when including input token 0, if only using 1 and up it seems to work. A cursory analysis would suggest that the min value in the sort function is multiplied with the indicies which makes it indistinguishable if the minimum value is 1.

The most-freq function does not work in the same method as the original RASP paper (despite the Tracr paper claiming they recreated the RASP function in Tracr). Instead of backfilling with BOS tokens it simply sorts all tokens in groups. The most-freq function (make_sort_freq) is also hardcoded to only accept 1 as the min_key value for some reason. I could fix this but it is not really a high priority (and seemingly breaks the sort function)

The most-freq function seems to fail sometimes (always?) when there are mutiple groups of the same count. Maybe they did not actually sort the output based on token grupings and only on frequency? Need to check. That apears to be the case. The Tracr make_sort_freq function is lazy and does not differentiate between tokens as long as the count is the same.

Shuffle dyck 
* The RASP paper uses the tokens T, P and F to account for if a dyck-k sequence is legal, possible legal or not legal for each token in the sequence. The Tracr implementation on the other hand only uses 1 or 0 to show if the entire sequence is legal or not. This is a far simpler solution yet for some reason they explicitly claim that this is how it is implemented in the RASP paper in their code ???
* If tokens are randomly selected most sequences will be unblanaced e.g. only even sequences can be balanced and if the sequence starts with a end token it wll be unblanced.
* I should probably try to generate the sequence such that the probability of a balanced sequence is roughly 50%

#### Analyzing weights

What do I need to look out for? All of these should probably be applied layerwise (for each matrix of weights) and globally
* Maximum/minimum values?
* Binary values?
* All same values?
* Percentage which is 0?

It seems like all of the "b" weights are zero vectors for the given models. As such I feel like I should mostly stay away from those vectors when adding noise and training

Many of the layer weights are zero. The layer weights are usually binary or ternary, very rarely do the layer assume more values than 3.

The percentage of values which are zero is usually between 90 and 100%.

#### Adding noise

Flipping a set amunt of bits will often have no effect. The influence of a flipped bit is heavily dependent on which bit is flipped. I cannot say what specific bits are highly influential though. This behaviour strikes me as odd since I would intuit that binary weights are done so for a reason, that is all weights should be relevant at some points.

Adding gaussian noise seems to give a better range of failure. The failed percentage increases "exponentially"ish with how large the noise is unlike bitflips which can cause large errors or no difference by flipping a single bit.

#### Training

Haiku is needlessly complicated. E.g. generating new sequences requires manually updating the rng_key each time instead of doing it within the functions themselves. Everything is wrapped with mutplie layers of functions and classes which makes it very difficult to keep track of what is what