In [1]:
import jax
import numpy as np
import matplotlib.pyplot as plt

# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

import sys 
sys.path.append("src")

%load_ext autoreload
%autoreload 2

RuntimeError: jaxlib version 0.4.26 is newer than and incompatible with jax version 0.4.21. Please update your jax and/or jaxlib packages.

In [2]:
from src.functions import *
from src.Model import Model

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

In [5]:
# Sanity check: generate shuffle-dyck-2 data and print the first five samples.
# The second argument is the maximum sequence length (it is passed as `maxSeqLen`
# in the pipeline cell further down); the third is presumably the number of
# samples to generate -- TODO confirm against src/functions.
data = generateData("shuffle_dyck2", 5, 100)
print(data[:5])

[(['BOS', '{', '}'], ['BOS', 1, 1]), (['BOS', '(', ')'], ['BOS', 1, 1]), (['BOS', ')', '(', '(', '{'], ['BOS', 0, 0, 0, 0]), (['BOS', '{', '}', '{', '}'], ['BOS', 1, 1, 1, 1]), (['BOS', '(', ')', '{', '('], ['BOS', 0, 0, 0, 0])]


In [6]:
# Print balance statistics for freshly generated Dyck data, once per task and
# per maximum sequence length, to check how well the classes are balanced.
for header, task in (("dyck1", "shuffle_dyck1"), ("\ndyck2", "shuffle_dyck2")):
    print(header)
    for maxLen in (5, 10, 15, 50):
        checkDyckBalance(generateData(task, maxLen, 10000))

#Seems to work fairly well, roughly between 40 and 50% is balanced depending on the maximum size

dyck1
Percentage of data which is:
Of odd length: 12.55
Balanced: 42.95
Percentage of data which is:
Of odd length: 8.89
Balanced: 48.83
Percentage of data which is:
Of odd length: 12.88
Balanced: 47.76
Percentage of data which is:
Of odd length: 11.6
Balanced: 52.58

dyck2
Percentage of data which is:
Of odd length: 12.46
Balanced: 41.56
Percentage of data which is:
Of odd length: 8.45
Balanced: 48.38
Percentage of data which is:
Of odd length: 12.15
Balanced: 47.63
Percentage of data which is:
Of odd length: 11.46
Balanced: 52.11


In [7]:
#Making sure the entire pipeline for testing the model works
# NOTE: `name`, `maxSeqLen`, `data`, `model` and `booleanAccuracy` are notebook
# globals; later cells read them, so keep the names and the statement order.
name = "shuffle_dyck2"
maxSeqLen = 5
data = generateData(name, maxSeqLen, 1000)
model = generateModel(name, maxSeqLen)

# Show a few (tokens, labels) samples before evaluating.
print(data[:5])

# Presumably evaluateModel returns one 0/1 (or bool) entry per sample, so the
# mean is the fraction of correctly handled samples -- TODO confirm in src/Model.
booleanAccuracy = model.evaluateModel(data)
accuracy=np.mean(booleanAccuracy)
print("Accuracy:",accuracy)

[(['BOS', '(', ')'], ['BOS', 1, 1]), (['BOS', '}', '{'], ['BOS', 0, 0]), (['BOS', '(', ')', '}', '{'], ['BOS', 0, 0, 0, 0]), (['BOS', '{', '}', '(', '{'], ['BOS', 0, 0, 0, 0]), (['BOS', '}', '('], ['BOS', 0, 0])]
Evaluating model: shuffle_dyck2
Accuracy: 1.0


In [8]:
#How to look at specific data points
# `booleanAccuracy-1` is non-zero exactly where an entry equals 0/False, so
# argwhere returns the indices of the failed samples.
# Assumes booleanAccuracy is a NumPy array of 0/1 or bool values -- TODO confirm.
print(np.argwhere(booleanAccuracy-1))   #Numpy list where evaluation failed
# Index 7 is just an example sample; substitute an index printed above.
print(data[7])

# Run the model on the token sequence (first tuple element) of that sample.
print(model.apply(data[7][0]))

[]
(['BOS', '{', '}'], ['BOS', 1, 1])
['BOS', True, True]


In [9]:
#Quick function to check for if all "b" weights are truly zero
def analyzeB(model: Model):
    """Report value statistics for every "b" parameter of a model.

    Walks ``model.model.params`` (a mapping of layer name to a mapping of
    parameter name to array). For each parameter named ``"b"`` it counts how
    often each distinct value occurs and passes the counts to
    ``calculateWeightStatistics`` from ``src.Model``.

    Args:
        model: Wrapper object whose ``model.params`` attribute holds the
            nested parameter mapping described above.
    """
    # Imported once per call, rather than once per "b" parameter inside the loop.
    from collections import Counter
    from src.Model import calculateWeightStatistics

    for layer in model.model.params.values():
        for paramName, weight in layer.items():
            if paramName != "b":
                continue

            # Map each distinct value (as a Python float) to its number of
            # occurrences. Converted back to a plain dict so the callee
            # receives the same type as before.
            weightCounter = dict(Counter(float(t) for t in weight.flatten()))

            # The second argument is passed exactly as before; presumably it
            # switches on printing of the statistics -- TODO confirm.
            calculateWeightStatistics(weightCounter, True)

# Build the "reverse" model and matching data with a maximum sequence length of 5.
name = "reverse"
maxSeqLen = 5
data = generateData(name, maxSeqLen, 1000)
model = generateModel(name, maxSeqLen)

#Display weight layer statistics of the model
# Optional: uncomment the next line to also inspect only the "b" parameters.
#analyzeB(model)
model.updateWeightStatistics()
model.printWeightStatistics()

TransformerConfig(num_heads=1, num_layers=4, key_size=12, mlp_hidden_size=30, dropout_rate=0.0, activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x0000023449621CA0>, layer_norm=False, causal=False)

Layer analysis:
pos_embed
	 embeddings
	  N: 270	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 98.15
token_embed
	 embeddings
	  N: 315	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 95.56
transformer/layer_0/attn/key
	 w
	  N: 540	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 98.89
transformer/layer_0/attn/linear
	 w
	  N: 540	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 99.81
transformer/layer_0/attn/query
	 w
	  N: 540	 min/max: 0.00/100.00	 nValues: 2	 percentageZero: 95.19
transformer/layer_0/attn/value
	 w
	  N: 540	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 99.81
transformer/layer_0/mlp/linear_1
	 w
	  N: 1350	 min/max: -75.00/100.00	 nValues: 13	 percentageZero: 98.37
transformer/layer_0/mlp/linear_2
	 w
	  N: 1350	 min/max: -1.00/1.00	 nValues

In [10]:
#Testing adding noise
# Regenerate the "reverse" model, then add Gaussian noise to it.
name = "reverse"
maxSeqLen = 5
model = generateModel(name, maxSeqLen)
# NOTE(review): presumably `amount` is the fraction of weights that are
# perturbed and `param` is the noise scale -- TODO confirm in src/Model.
model.addNoise(noiseType="gaussian", amount=1.0, param=0.001)
# Recompute and print the per-layer statistics so they reflect the noisy weights.
model.updateWeightStatistics()
model.printWeightStatistics()

TransformerConfig(num_heads=1, num_layers=4, key_size=12, mlp_hidden_size=30, dropout_rate=0.0, activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x0000023449621CA0>, layer_norm=False, causal=False)

Layer analysis:
pos_embed
	 embeddings
	  N: 270	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 98.15
token_embed
	 embeddings
	  N: 315	 min/max: 0.00/1.00	 nValues: 2	 percentageZero: 95.56
transformer/layer_0/attn/key
	 w
	  N: 540	 min/max: -0.00/1.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/attn/linear
	 w
	  N: 540	 min/max: -0.00/1.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/attn/query
	 w
	  N: 540	 min/max: -0.00/100.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/attn/value
	 w
	  N: 540	 min/max: -0.00/1.00	 nValues: 540	 percentageZero: 0.00
transformer/layer_0/mlp/linear_1
	 w
	  N: 1350	 min/max: -75.00/100.00	 nValues: 1350	 percentageZero: 0.00
transformer/layer_0/mlp/linear_2
	 w
	  N: 1350	 min/max: -1.00/1.00

In [11]:
#Testing evaluating model after adding noise
# Relies on `name`, `maxSeqLen` and the noised `model` left in the kernel by
# the previous cell; run that cell first.
data = generateData(name, maxSeqLen, 1000)
booleanAccuracy = model.evaluateModel(data)
accuracy=np.mean(booleanAccuracy)
print("Accuracy:",accuracy)

Evaluating model: reverse
Accuracy: 0.991


## Training

In [21]:
#Trying to figure out what kind of haiku model the tracr models are (how do they relate to pure haiku models) and how I can train these models

#A non-stochastic simple haiku model
import haiku as hk
import jax.numpy as jnp

class MyLinear1(hk.Module):
    """Minimal dense layer computing ``x @ w + b`` as a Haiku module."""

    def __init__(self, output_size, name=None):
        """Remember the output width; ``name`` is forwarded to ``hk.Module``."""
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        """Apply the affine map along the last axis of ``x``."""
        in_features = x.shape[-1]
        out_features = self.output_size

        # Weight initializer is a truncated normal parameterized by
        # 1/sqrt(in_features); the bias starts as all ones.
        weight = hk.get_parameter(
            "w",
            shape=[in_features, out_features],
            dtype=x.dtype,
            init=hk.initializers.TruncatedNormal(1. / np.sqrt(in_features)),
        )
        bias = hk.get_parameter(
            "b",
            shape=[out_features],
            dtype=x.dtype,
            init=jnp.ones,
        )
        return jnp.dot(x, weight) + bias

def _forward_fn_linear1(x):
    """Build a two-output ``MyLinear1`` and return its result for ``x``."""
    return MyLinear1(output_size=2)(x)

# Turn the forward function into a pure Haiku model whose apply takes no rng argument.
haikuModel = hk.without_apply_rng(hk.transform(_forward_fn_linear1))

#Tracr model
tracrModel = generateModel("sort", 5).model

# Print both forward callables side by side to compare what kind of object each is.
# (A bare `tracrModel.get_compiled_model` expression used to sit here; its value
# was discarded, so it had no effect and was removed.)
print(haikuModel.apply)
print("--------")
print(tracrModel.forward)
print("--------")

<function without_apply_rng.<locals>.apply_fn at 0x00000234649C8220>
--------
<function without_apply_rng.<locals>.apply_fn at 0x00000234649C87C0>
--------


## Notes

#### Testing the base functions and generating the test data

The sort function does not reach 100% accuracy. This only seems to apply when input token 0 is included; when only using tokens 1 and up it seems to work. A cursory analysis would suggest that the min value in the sort function is multiplied with the indices, which makes it indistinguishable if the minimum value is 1.

The most-freq function does not work the same way as in the original RASP paper (despite the Tracr paper claiming they recreated the RASP function in Tracr). Instead of backfilling with BOS tokens it simply sorts all tokens in groups. The most-freq function (make_sort_freq) is also hardcoded to only accept 1 as the min_key value for some reason. I could fix this, but it is not really a high priority (and it seemingly breaks the sort function).

The most-freq function seems to fail sometimes (always?) when there are multiple groups with the same count. Maybe they did not actually sort the output based on token groupings and only on frequency? Need to check. That appears to be the case. The Tracr make_sort_freq function is lazy and does not differentiate between tokens as long as the count is the same.

Shuffle dyck 
* The RASP paper uses the tokens T, P and F to mark, for each token in the sequence, whether a dyck-k sequence is legal, possibly legal or not legal. The Tracr implementation on the other hand only uses 1 or 0 to show whether the entire sequence is legal or not. This is a far simpler solution, yet for some reason they explicitly claim in their code that this is how it is implemented in the RASP paper (?)
* If tokens are randomly selected, most sequences will be unbalanced, e.g. only even-length sequences can be balanced, and if the sequence starts with an end token it will be unbalanced.
* I should probably try to generate the sequence such that the probability of a balanced sequence is roughly 50%

#### Analyzing weights

What do I need to look out for? All of these should probably be applied layerwise (for each matrix of weights) and globally
* Maximum/minimum values?
* Binary values?
* All same values?
* Percentage which is 0?

It seems like all of the "b" weights are zero vectors for the given models. As such I feel like I should mostly stay away from those vectors when adding noise and training

Many of the layer weights are zero. The layer weights are usually binary or ternary; very rarely do the layers take more than 3 distinct values.

The percentage of values which are zero is usually between 90 and 100%.

#### Adding noise

Flipping a set amount of bits will often have no effect. The influence of a flipped bit is heavily dependent on which bit is flipped. I cannot say which specific bits are highly influential though. This behaviour strikes me as odd, since I would intuit that binary weights are binary for a reason, that is, all weights should be relevant at some point.

Adding gaussian noise seems to give a better range of failure. The failed percentage increases "exponentially"ish with how large the noise is unlike bitflips which can cause large errors or no difference by flipping a single bit.

#### Training

Haiku is needlessly complicated. E.g. generating new sequences requires manually updating the rng_key each time instead of doing it within the functions themselves. Everything is wrapped in multiple layers of functions and classes, which makes it very difficult to keep track of what is what.

The tracr models are of the AssembledTransformerModel class which can be found in the tracr directory under tracr/tracr/compiler/assemble.py. The class seems to be a wrapper which contains things such as the parameters, configuration parameters and the forward pass function of a haiku model

I cannot seem to find a convenient method to train a haiku model (like for example how a sklearn model assumes you want to train it from data). As such I think I need to figure out how a transformer is trained and manually apply that training to the parameters in a training loop. 

    Note, the haiku documentation for gradients states that "You only need this in a very specific case that you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state()". I am not sure what exactly this entails, but it sounds like they expect you to not use the grad function usually, even though they do use the grad function in their "training a subset of parameters" example.

Quickly look into the VectorQuantizer, as it returns some loss for the optimizer, whatever that means. The referenced paper suggests it is used for training.

Look into the optax library. Seems to be perfect for what I want. At least if "https://github.com/google-deepmind/dm-haiku/blob/main/examples/transformer/train.py" is anything to go by

If I use softmax I run the risk of trying to train even on the samples which are correct. No clue if this is a problem or not though lol