In [1]:
from torch import nn
from associative_recall import _gap_power_distr_ar
import torch
import numpy as np
from einops import rearrange
import matplotlib.pyplot as plt

import math
from torch.fft import rfft, irfft


  from .autonotebook import tqdm as notebook_tqdm


### Getting MQAR data

In [2]:
# MQAR (multi-query associative recall) task configuration.
vocab_size = 128
input_seq_len = 64
num_examples = 5
num_kv_pairs = 8
model_dimension = input_seq_len * vocab_size
context_size = 2 * num_kv_pairs  # presumably the token span of the key-value prefix (two tokens per pair)

# Generate the synthetic dataset; the first two returned items are the
# input token ids and the targets (-100 marks unscored positions).
data = _gap_power_distr_ar(
    vocab_size=vocab_size,
    input_seq_len=input_seq_len,
    num_examples=num_examples,
    num_kv_pairs=num_kv_pairs,
    seed=123,
)
inputs, targets = data[0], data[1]

# Inspect the first example.
print(inputs[0])
print(targets[0])

# Understanding this data -- -100's are not scored. Notice in the first portion of the sequence there are kv pairs (25, 91), (40, 87), etc., and later the targets contain a "91" and an "87". At the corresponding positions in the input we see the key "25" (followed by a random token "12") and the key "40" (followed by a random token "8"). The model needs to predict the non- -100 tokens in the target sequence, i.e. recall each key's value at the position where the key reappears.

tensor([ 25,  91,  40,  87,  16,  89,  32,  95,  44, 121,  55,  73,  19,  86,
         41, 123,  50,  16,  43,  25,  12,  40,   8,  78,  66,   6, 105,  44,
         60,  41,  48,  55,  69,  78,  58, 101,  24, 112,  64, 127, 114,   5,
         73,   4,  17, 101,  63,  46,   7,  49,  48,   9, 106,  19,  58,  25,
         84,  32,  52, 104,   0, 104,  61,  64])
tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100,   89, -100,   91, -100,   87, -100, -100,
        -100, -100, -100,  121, -100,  123, -100,   73, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100,   86, -100, -100, -100,   95, -100, -100,
        -100, -100, -100, -100])


### A sample solution to MQAR with autocorrelation

In [9]:
# Scoring Function
def run_check(y, label, embeddings, input):
    """Decode y with the embedding matrix as an argmax LM head and score it.

    Args:
        y: model output; only y[0] (the first batch element) is decoded.
            In this notebook it is laid out as (batch, dim, length).
        label: (length,) targets; entries equal to -100 are not scored.
        embeddings: an nn.Embedding whose (vocab, dim) weight is reused as the
            output head (read via `.weight.data`, i.e. without autograd).
        input: token ids; only printed for inspection.

    Returns:
        0-dim float64 tensor: the fraction of scored positions predicted
        correctly (NaN if no position is scored).
    """
    head = embeddings.weight.data
    logits = head @ y[0]                  # (vocab, length)
    preds = logits.argmax(dim=0)          # best-matching vocab id per position
    print(f"{input=}")
    print(f"{preds=}")
    print(f"{label=}")
    scored = label != -100
    correct = label == preds
    return correct[scored].to(float).mean()

In [8]:
def fft_convolution(x: torch.Tensor, k: torch.Tensor):
    """Causal (linear) convolution of x with k along the last axis, via FFT.

    Both inputs are zero-padded to twice the length of x, so the circular FFT
    product does not wrap around for filters no longer than x; the result is
    cut back to x's length:

        y[..., t] = sum_{s <= t} x[..., s] * k[..., t - s]

    Args:
        x: (..., l) signal; its leading dims broadcast against k's
            (e.g. (b, d, 1, l) against (b, 1, h, l)).
        k: (..., l) filter; leading dims broadcast against x's.

    Returns:
        Tensor with the broadcast leading shape and length l on the last axis.
    """
    n = x.shape[-1]
    padded = 2 * n
    spectrum = torch.fft.rfft(x, n=padded) * torch.fft.rfft(k, n=padded)
    return torch.fft.irfft(spectrum, n=padded)[..., :n]

def fft_autocorrelation(q: torch.Tensor, k: torch.Tensor=None):
    """Correlate q with k (with itself when k is None) along the last axis.

    For every non-negative lag tau < l, where l = q.shape[-1]:

        out[..., tau] = sum_t q[..., t + tau] * k[..., t]

    computed as irfft(rfft(q) * conj(rfft(k))). Zero-padding to 2*l keeps the
    circular correlation from wrapping around for these lags.

    Args:
        q: (..., l) real tensor.
        k: (..., l) real tensor broadcastable against q; defaults to q.

    Returns:
        (..., l) tensor of correlation scores indexed by lag.
    """
    other = q if k is None else k
    length = q.shape[-1]
    n_fft = 2 * length
    cross_spectrum = rfft(q, n=n_fft) * torch.conj(rfft(other, n=n_fft))
    return irfft(cross_spectrum, n=n_fft)[..., :length]

def get_top_k(q, top_k=3):
    """Return the top_k most self-similar non-zero lags ("gaps") of q.

    Each lag is scored by the autocorrelation of q with itself, summed over
    dim 1 (the embedding axis for the (batch, dim, length) inputs used in this
    notebook). The top_k highest-scoring lags are then selected.

    Lag 0 compares every position with itself. Its score is trivially maximal
    and carries no information about repeated tokens, so it is excluded by
    setting its score to -inf. Setting it to 0 would not be enough: summed
    scores at other lags can be negative, and lag 0 could then still be
    selected.

    Args:
        q: (batch, dim, length) tensor (assumed layout; the sum is over dim=1).
        top_k: number of lags to return.

    Returns:
        (batch, top_k) int64 tensor of lags, in no particular order
        (torch.topk is called with sorted=False). All lags are >= 1 whenever
        top_k < length.
    """
    scores = torch.sum(fft_autocorrelation(q, q), dim=1)  # (batch, length), indexed by lag
    # Exclude the trivial self-match at lag 0 so it can never be picked.
    scores[..., 0] = float("-inf")
    _, indices = torch.topk(scores, k=top_k, sorted=False, dim=-1)
    return indices

In [19]:
# Toy examples: [token sequence, token that should be predicted at index len-2].
# Index len-2 holds the query key (a key seen earlier in the sequence); it is
# the only position scored below.
examples = [
    [torch.tensor([7, 3, 6, 1, 6, 1, 7, 3, 4, 0, 6, 14]), 1],   # 1 is the next-token prediction for "6"
    [torch.tensor([4, 0, 6, 1, 4, 0, 7, 3, 6, 2]), 1],
    [torch.tensor([6, 1, 7, 2, 4, 0, 9, 9, 9, 7, 2, 6, 1, 4, 7,]), 0],
    [torch.tensor([4, 2, 6, 1, 7, 3, 6, 8]), 1],
]

score = 0
# `input_ids` rather than `input`, so the builtin is not shadowed for later cells.
for input_ids, target_token in examples:
    # Score only the query position (index len-2); -100 marks unscored positions.
    label = torch.tensor([-100] * (len(input_ids) - 2) + [target_token, -100])

    assert len(input_ids) == len(label)
    length, batch, dim, top_k = len(input_ids), 1, 16, 2
    # Per-example vocabulary size, kept under its own name so the MQAR
    # `vocab_size` configured earlier in the notebook is not overwritten.
    example_vocab_size = max(input_ids.tolist() + label.tolist()) + 1
    torch.manual_seed(0)  # same random embeddings on every run

    # embed the input: (length, dim) -> (dim, length) -> (batch, dim, length)
    embeddings = nn.Embedding(example_vocab_size, dim)
    x = embeddings(input_ids).T
    batch_x = x.unsqueeze(0).repeat(batch, 1, 1)
    q = v = batch_x  # queries/keys and values share one embedding in this toy setup

    # compute autocorrelation and get the top-k gaps (lags) that we want to shift by
    indices = get_top_k(q, top_k=top_k)  # (batch, top_k)

    # One delta filter per gap, shape (batch, top_k, length). Under the causal
    # convolution below, the key filter makes position t see the token at
    # t - gap, and the value filter the token right after it (t - gap + 1).
    # NOTE: a gap of 0 would make `gap - 1` index the last position (-1).
    batch_idx = torch.arange(batch)
    key_filt = torch.zeros(batch, indices.shape[-1], length, device=q.device, dtype=q.dtype)
    value_filt = torch.zeros_like(key_filt)
    for j in range(indices.shape[-1]):
        key_filt[batch_idx, j, indices[:, j]] = 1        # conv to pull forward keys
        value_filt[batch_idx, j, indices[:, j] - 1] = 1  # conv to pull forward values

    # apply the filters, broadcasting over dim and gaps -> (batch, dim, top_k, length)
    keys = fft_convolution(
        rearrange(q, "b d l -> b d 1 l"),
        rearrange(key_filt, "b h l -> b 1 h l"),
    )
    values = fft_convolution(
        rearrange(v, "b d l -> b d 1 l"),
        rearrange(value_filt, "b h l -> b 1 h l"),
    )

    # softmax over gaps: weight each gap by how well its shifted key matches
    # the current query token, then mix the corresponding shifted values
    mask = torch.softmax(torch.einsum("b d l, b d h l -> b h l", q, keys), dim=1)
    y = torch.einsum("b d h l, b h l -> b d l", values, mask)

    # apply an "argmax" lm head (tied to the embeddings) and check the scored prediction
    result = run_check(y, label, embeddings, input_ids)
    print(f"{result=}\n")
    score += result

# overall averaged score
print(f"{score/len(examples)=}")



input=tensor([ 7,  3,  6,  1,  6,  1,  7,  3,  4,  0,  6, 14])
preds=tensor([3, 6, 7, 3, 6, 1, 3, 6, 1, 3, 1, 0])
label=tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    1, -100])
result=tensor(1., dtype=torch.float64)

input=tensor([4, 0, 6, 1, 4, 0, 7, 3, 6, 2])
preds=tensor([6, 6, 2, 4, 0, 6, 0, 6, 1, 4])
label=tensor([-100, -100, -100, -100, -100, -100, -100, -100,    1, -100])
result=tensor(1., dtype=torch.float64)

input=tensor([6, 1, 7, 2, 4, 0, 9, 9, 9, 7, 2, 6, 1, 4, 7])
preds=tensor([7, 2, 7, 1, 6, 2, 5, 3, 6, 1, 7, 1, 7, 0, 9])
label=tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100,    0, -100])
result=tensor(1., dtype=torch.float64)

input=tensor([4, 2, 6, 1, 7, 3, 6, 8])
preds=tensor([4, 2, 6, 1, 7, 3, 1, 7])
label=tensor([-100, -100, -100, -100, -100, -100,    1, -100])
result=tensor(1., dtype=torch.float64)

score/len(examples)=tensor(1., dtype=torch.float64)
