In [1]:
import torch
def multiplication_mod_p_data(p, eq_token, op_token):
    """x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p
    """
    x = torch.arange(p)
    y = torch.arange(1, p)
    x, y = torch.cartesian_prod(x, y).T

    eq = torch.ones_like(x) * eq_token
    op = torch.ones_like(x) * op_token
    result = x * y % p

    # "All of our experiments used a small transformer trained on datasets of
    # equations of the form a◦b = c, where each of “a”, “◦”, “b”, “=”, and “c”
    # is a seperate token"
    return torch.stack([x, op, y, eq, result])

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--p", type=int, default=97)
args = parser.parse_args('')

eq_token = args.p
op_token = args.p + 1

data = multiplication_mod_p_data(args.p, eq_token, op_token)

train_idx, valid_idx = torch.randperm(data.shape[1]).split(data.shape[1] // 2)
train_data, valid_data = data[:, train_idx], data[:, valid_idx]

In [2]:
from main import Block
import torch.nn as nn
class Decoder(nn.Module):
    """Causal Transformer decoder
    """

    def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(seq_len, dim)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(dim, num_heads))

        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = x.transpose(1, 0)
        h = self.token_embeddings(x)
        positions = torch.arange(x.shape[0], device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)
        for layer in self.layers:
            h = layer(h)

        h = self.ln_f(h)
        logits = self.head(h)

        logits = logits.transpose(1, 0)
        return logits

In [3]:
from torch.utils.data import Dataset, DataLoader
# Algorithmic data
dataset_train = train_data.T
dataset_test = valid_data.T
batch_size = 32  # Adjust as needed
    
# Create a DataLoader with the custom sampler and custom collate function
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)


In [4]:
next(iter(train_loader)).shape

torch.Size([32, 5])

In [5]:
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
model = Decoder(
    dim=128, num_layers=2, num_heads=4, num_tokens=args.p + 2, seq_len=5
).to(device)


In [6]:
# import torch
# results = torch.load('results/res_test_none.pt')
# # results = torch.load('results/res_test_ma_w100_l5_wd10e-02.pt')
# # results.keys()

In [7]:
# weights = results['net']
# model.load_state_dict(weights[0])

In [8]:
import torch
import torch.nn.functional as F
from torch_influence import BaseObjective

class MyObjective(BaseObjective):

    def train_outputs(self, model, batch):
        print("train_outputs batch shape", batch.shape)
        return model(batch[:,:-1])

    def train_loss_on_outputs(self, outputs, batch):
        print("train_loss_on_outputs outputs shape", outputs.shape)
        print("batch shape", batch.shape)
        return F.cross_entropy(outputs[:, -1], batch[:,-1])  # mean reduction required

    def train_regularization(self, params):
        return 0.01 * torch.square(params.norm())

    # training loss by default taken to be 
    # train_loss_on_outputs + train_regularization

    def test_loss(self, model, params, batch):
        print("test_loss batch shape", batch.shape)
        return F.cross_entropy(model(batch[:, :-1]), batch[:,-1])  # no regularization in test loss

In [9]:
train_batch = next(iter(train_loader)).to(device)
train_batch.shape

torch.Size([32, 5])

In [10]:
train_batch = next(iter(test_loader)).to(device)
train_batch.shape

torch.Size([32, 5])

In [11]:
model(train_batch)

tensor([[[-0.3127,  0.1468,  0.9827,  ...,  0.1660, -0.9748, -0.7163],
         [-0.0713, -0.2671, -0.5012,  ..., -0.2699,  0.3985, -0.2814],
         [ 0.4764, -0.6265,  0.9626,  ...,  0.0983,  0.8717, -0.6418],
         [-0.9972, -0.4798, -0.3507,  ..., -0.1703, -0.5253, -1.4174],
         [ 0.5469,  0.6592, -0.1290,  ...,  0.1887, -0.3271, -0.2404]],

        [[-0.1489, -0.1876,  1.3166,  ...,  0.3554, -0.6584,  0.2875],
         [ 0.0433, -0.4908, -0.5456,  ..., -0.2692,  0.1652, -0.0075],
         [ 0.2517, -0.9094, -0.3767,  ..., -0.7832,  0.3732, -0.6379],
         [-0.8622, -0.4794, -0.4884,  ..., -0.2298, -0.5856, -1.2069],
         [ 0.1768,  0.5179, -0.1725,  ..., -0.3211, -0.7126,  0.4036]],

        [[-0.1489, -0.1876,  1.3166,  ...,  0.3554, -0.6584,  0.2875],
         [ 0.0433, -0.4908, -0.5456,  ..., -0.2692,  0.1652, -0.0075],
         [-0.0018, -0.1546,  0.0510,  ..., -0.0280,  0.2576,  0.4686],
         [-0.8472, -0.4028, -0.5288,  ..., -0.1719, -0.6551, -1.1679],
  

In [12]:
assert False

AssertionError: 

In [13]:
from torch_influence import AutogradInfluenceModule
   
module = AutogradInfluenceModule(
    model=model,
    objective=MyObjective(),  
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    damp=0.001
)


train_outputs batch shape torch.Size([32, 5])
train_loss_on_outputs outputs shape torch.Size([32, 4, 99])
batch shape torch.Size([32, 5])


RuntimeError: derivative for aten::_scaled_dot_product_flash_attention_for_cpu_backward is not implemented

In [None]:
# influence scores of training points 1, 2, and 3 on test point 0
scores = module.influences([1, 2, 3], [0])

NameError: name 'module' is not defined