In [1]:
import jax
import numpy as np
import matplotlib.pyplot as plt

# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

In [2]:
#Plotting functions taken from the Tracr repo 

#@title Plotting functions
def tidy_label(label, value_width=5):
  if ':' in label:
    label, value = label.split(':')
  else:
    value = ''
  return label + f":{value:>{value_width}}"


def add_residual_ticks(model, value_width=5, x=False, y=True):
  if y:
    plt.yticks(
            np.arange(len(model.residual_labels))+0.5, 
            [tidy_label(l, value_width=value_width)
              for l in model.residual_labels], 
            family='monospace',
            fontsize=20,
    )
  if x:
    plt.xticks(
            np.arange(len(model.residual_labels))+0.5, 
            [tidy_label(l, value_width=value_width)
              for l in model.residual_labels], 
            family='monospace',
            rotation=90,
            fontsize=20,
    )


def plot_computation_trace(model,
                           input_labels,
                           residuals_or_outputs,
                           add_input_layer=False,
                           figsize=(12, 9)):
  fig, axes = plt.subplots(nrows=1, ncols=len(residuals_or_outputs), figsize=figsize, sharey=True)
  value_width = max(map(len, map(str, input_labels))) + 1

  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes)):
    plt.sca(ax)
    plt.pcolormesh(layer[0].T, vmin=0, vmax=1)
    if i == 0:
      add_residual_ticks(model, value_width=value_width)
    plt.xticks(
        np.arange(len(input_labels))+0.5,
        input_labels,
        rotation=90,
        fontsize=20,
    )
    if add_input_layer and i == 0:
      title = 'Input'
    else:
      layer_no = i - 1 if add_input_layer else i
      layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'
      title = f'{layer_type} {layer_no // 2 + 1}'
    plt.title(title, fontsize=20)


def plot_residuals_and_input(model, inputs, figsize=(12, 9)):
  """Applies model to inputs, and plots the residual stream at each layer."""
  model_out = model.apply(inputs)
  residuals = np.concatenate([model_out.input_embeddings[None, ...],
                              model_out.residuals], axis=0)
  plot_computation_trace(
      model=model,
      input_labels=inputs,
      residuals_or_outputs=residuals,
      add_input_layer=True,
      figsize=figsize)


def plot_layer_outputs(model, inputs, figsize=(12, 9)):
  """Applies model to inputs, and plots the outputs of each layer."""
  model_out = model.apply(inputs)
  plot_computation_trace(
      model=model,
      input_labels=inputs,
      residuals_or_outputs=model_out.layer_outputs,
      add_input_layer=False,
      figsize=figsize)

## Basic Testing

In [3]:
#Return the input in reverse
def reverse() -> rasp.SOp:
    length = lib.make_length()
    opp_index = length - rasp.indices - 1
    flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
    return rasp.Aggregate(flip, rasp.tokens)

reverse_model = compiling.compile_rasp_to_model(reverse(), {0,1,2,3,4}, 7, compiler_bos="BOS")

#Model length does not include bos token

In [4]:
reverese_output = reverse_model.apply(["BOS",1,4,3,0,0,0,1])
print(reverese_output.decoded)
#reverese_output.attn_logits
#reverese_output.input_embeddings
#reverese_output.layer_outputs
#reverese_output.residuals

['BOS', 1, 0, 0, 0, 3, 4, 1]


In [14]:
#Testing atoi and adding indices to input

#If atoi is used all inputs need to be str
def atoi() -> rasp.SOp:
    return rasp.Map(lambda x: ord(x)-ord("0"), rasp.tokens)

def addIndicesAtoi() -> rasp.SOp:
    return rasp.indices + atoi()

def addIndices() -> rasp.SOp:
    #return rasp.indices + atoi()
    return rasp.indices + rasp.tokens

addIndicesModel = compiling.compile_rasp_to_model(addIndicesAtoi(), {"0","1","2","3","4","+"}, 5, compiler_bos="BOS")
print(addIndicesModel.apply(["BOS", "4","3","1","2"]).decoded)
#Gives output as int even though it does not accept int as input
#Aka can give ouputs not specified in input domain (At least numeric)

addIndicesModel = compiling.compile_rasp_to_model(addIndices(), {1,2,3,4}, 5, compiler_bos="BOS")
print(addIndicesModel.apply(["BOS", 4,3,1,2]).decoded)

#addIndicesModel = compiling.compile_rasp_to_model(addIndices(), {"0","1","2","3","4","+",0}, 5, compiler_bos="BOS")    
#Will not compile due to to compiler failing to compare number with str in atoi (during feasibility check?)

['BOS', 4, 4, 3, 5]
['BOS', 4, 4, 3, 5]


## Trying to create addition model based on my RASPy implementation

In [22]:
#Since the default does not work in a useful way I made my own version
def defaultAggregate(sel: rasp.Selector, value: rasp.SOp, default="_") -> rasp.SOp:
    length = lib.make_length().named("length")
    ones = rasp.numerical(rasp.Map(lambda x: 1, length)).named("ones")
    aggInd = rasp.numerical(rasp.Aggregate(sel, ones, default=0)).named("aggInd")
    #Need to be categorical as SequenceMap only need both SOp to be same class
    #There seem to be major misscalculations in the actual value of aggInd since x>0.1 gives false results
    #   will likely lead to major bugs at some point (:
    aggIndCat = rasp.Map(lambda x: x > 0.5, aggInd).named("aggIndCat")
    baseAgg = rasp.Aggregate(sel, value).named("baseAgg")
    return rasp.SequenceMap(lambda x,y: y if x == True else default, aggIndCat, baseAgg)

def shift(i=1, default="_") -> rasp.SOp:
    indShift = (rasp.indices + i).named("indShift")
    shiftMask = rasp.Select(indShift, rasp.indices, rasp.Comparison.EQ).named("shiftMask")  
    #Implement manual default since Tracr does not seem to allow custom default tokens
    return defaultAggregate(shiftMask, rasp.tokens).named("shift")

    #A more efficient version which they use in rasp.lib
    #Circumvents the extra mlp (indShift) by adding directly in the Select function logic
    select_off_by_offset = rasp.Select(rasp.indices, rasp.indices,
                                        lambda k, q: q == k + offset)
    out = rasp.Aggregate(select_off_by_offset, sop, default=None)
    return out.named(f"shift_by({offset})")

#Testing
shiftModel = compiling.compile_rasp_to_model(shift(2), {"0","1","2","3","4","+","_"}, 5, compiler_bos="BOS")
#shiftModel = compiling.compile_rasp_to_model(shift(2), {"0","1","2","3","4","+"}, 5, compiler_bos="BOS") #default,"_", needs to be included in the allowed input in order to be encoded correctly
model = shiftModel
input = ["BOS","1","3","+","4","2"]
print(model.apply(input).decoded)

#plot_residuals_and_input(model=model, inputs=input, figsize=(10, 15))

#plot_layer_outputs(model=model, inputs=input, figsize=(10, 9))

['BOS', '+', '+', '1', '3', '+']


In [7]:

#detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

tempModel = compiling.compile_rasp_to_model(lib.make_frac_prevs(rasp.tokens=="b"), {"a", "b", "c","True","False"}, 7, compiler_bos="BOS")
model = tempModel
input = ["BOS", "a", "b", "c", "a", "b", "c"]
print(model.apply(input).decoded)

['BOS', 6.676483943880793e-15, 0.5, 0.3333333432674408, 0.25, 0.4000000059604645, 0.3333333432674408]


In [8]:
#cumulative sum of a boolean input
def cumsum(boolean: rasp.SOp) -> rasp.SOp:
    ints = rasp.numerical(boolean)
    cumsumSel = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ).named("cumsumSel")
    #return rasp.numerical(rasp.Map(lambda x: x, sop))
    #return rasp.numerical(rasp.Aggregate(cumsumSel, ints, default=0))
    cumsumFrac = rasp.numerical(rasp.Aggregate(cumsumSel, ints, default=0)).named("cumsumFrac")
    cumsumFracCat = rasp.Map(lambda x: x, cumsumFrac)       #This warps the results severely, seems to round down but unclear how much
    return cumsumFracCat
    return cumsumFracCat * (1 + rasp.indices)
    return rasp.SequenceMap(lambda x, y: x * y, cumsumFracCat, rasp.Map(lambda x: x+1,rasp.indices))
    return (cumsumFrac * (rasp.indices+1)).named("cumsum")

#cumsumModel = compiling.compile_rasp_to_model(cumsum(), {"0","1","2","3","4","T","F"}, 5, compiler_bos="BOS")
cumsumModel = compiling.compile_rasp_to_model(cumsum(rasp.tokens == "<"), {"<","x","h","y","l"}, 8, compiler_bos="BOS")
model = cumsumModel
input = ["BOS", "<","<","x","h","y","<","<","l"]
print(model.apply(input).decoded)

#plot_residuals_and_input(model=model, inputs=input, figsize=(10, 15))

['BOS', 1.0, 1.0, 0.5, 0.5, 0.3333333333333333, 0.5, 0.5, 0.5]


In [9]:
def test() -> rasp.SOp:
    sel = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    input = rasp.numerical(rasp.tokens*1)
    return rasp.numerical(rasp.Aggregate(sel, rasp.numerical(rasp.indices*10), default=0))

tempModel = compiling.compile_rasp_to_model(test(), {0,1,2,3,4,5}, 5, compiler_bos="BOS")
model = tempModel
input = ["BOS", 3,1,2,4,1]
print(model.apply(input).decoded)

NotImplementedError: ('Attention patterns can currently only average binary variables. Not:', {0, 40, 10, 20, 30})