In [1]:
from math import sqrt

import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Linear, LayerNorm, MultiheadAttention
import lightning.pytorch as pl

In [166]:
# class MyMultiheadAttention(Module):
#     # start from here: https://nn.labml.ai/transformers/mha.html#section-5

class LinearformerLayer(Module):
    """Bias-free self-attention followed by a bias-free linear map.

    The layer contains no residual connection, normalization or feed-forward
    non-linearity: it is ``linear(self_attn(x, x, x))`` and nothing else.

    TODO: ``MultiheadAttention`` requires the token dimension (``d_model``) to
    be divisible by the number of heads (``nhead``), which is an annoying
    restriction on the combinations that can be used here.

    Args:
        d_model: Size of each token representation (input and output).
        nhead: Number of attention heads.
        batch_first: Forwarded unchanged to ``MultiheadAttention``.
        device: Forwarded to the constructors of both sub-modules.
        dtype: Forwarded to the constructors of both sub-modules.
    """

    def __init__(self, d_model, nhead,
                 batch_first=True, device=None, dtype=None):
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)
        # The attention module is created first and the linear map second;
        # this fixes the order in which random initial weights are drawn.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            bias=False,
            batch_first=batch_first,
            **tensor_opts,
        )
        self.linear = Linear(d_model, d_model, bias=False, **tensor_opts)

    def forward(self, x):
        """Apply self-attention to ``x`` and then the linear map.

        N.B.: no ``key_padding_mask`` is passed, so every sequence in the
        batch is assumed to have the same length. Variable-length batches
        would need ``key_padding_mask``.
        """
        # The module returns a tuple; only the attended values (index 0) are used.
        attn_result = self.self_attn(x, x, x, need_weights=False)
        attended = attn_result[0]
        return self.linear(attended)

class LinearFormer(Module):
    """Stack of ``LinearformerLayer`` modules applied one after another.

    Args:
        num_layers: Number of ``LinearformerLayer`` modules in the stack.
        d_model: Size of each token representation.
        nhead: Number of attention heads in every layer.
        batch_first: Forwarded to every layer.
        device: Forwarded to every layer.
        dtype: Forwarded to every layer.
    """

    # Gain on the key projection of the hand-built models. With identity query
    # weights and a single head the attention logits are proportional to
    # _KEY_GAIN * <x_t, x_u>, so the softmax puts almost all of its weight on
    # the token whose dot product with the query token is smallest.
    _KEY_GAIN = -1000

    def __init__(self, num_layers, d_model, nhead,
                 batch_first=True, device=None, dtype=None):
        super().__init__()
        self.layers = torch.nn.Sequential(*(
            LinearformerLayer(d_model, nhead, batch_first=batch_first, device=device, dtype=dtype)
            for _ in range(num_layers)
        ))

    def forward(self, x):
        """Run ``x`` through every layer in order and return the result."""
        return self.layers(x)

    @staticmethod
    def _overwrite(param, target, noise_scale):
        """Set ``param`` in place to ``target`` (zeros if None) plus optional Gaussian noise."""
        if param is None:
            # The layer was built without this bias term; nothing to set.
            return
        if target is None:
            value = torch.zeros_like(param)
        else:
            # Match the parameter's device/dtype so the ``device`` and ``dtype``
            # requested by the caller are respected.
            value = target.to(device=param.device, dtype=param.dtype)
        if noise_scale:
            value = value + noise_scale * torch.randn_like(value)
        with torch.no_grad():
            param.copy_(value)

    @classmethod
    def _hand_built(cls, d_model, noise_scale, batch_first, device, dtype):
        """Build the one-layer, one-head hand-set model shared by ``farthest*``."""
        model = cls(1, d_model, 1, batch_first=batch_first, device=device, dtype=dtype)
        layer = model.layers[0]
        attn = layer.self_attn
        eye = torch.eye(d_model)
        # in_proj_weight stacks the query, key and value projections row-wise.
        in_proj = torch.cat([eye, cls._KEY_GAIN * eye, eye])
        # The order below is the order in which noise is drawn from the global RNG.
        assignments = (
            (layer.linear.weight, eye),
            (layer.linear.bias, None),
            (attn.in_proj_weight, in_proj),
            (attn.in_proj_bias, None),
            (attn.out_proj.weight, eye),
            (attn.out_proj.bias, None),
        )
        for param, target in assignments:
            cls._overwrite(param, target, noise_scale)
        return model

    @classmethod
    def farthest(cls, d_model, batch_first=True, device=None, dtype=None):
        """Return a one-layer, one-head model with hand-set weights.

        Query, value, output and trailing linear weights are identity matrices
        and the key weight is ``-1000`` times the identity; any bias present is
        zeroed. Each output token is therefore (almost exactly) the input token
        with the smallest dot product with the query token.

        Args:
            d_model: Size of each token representation.
            batch_first: Forwarded to the model constructor.
            device: Device of the returned model's parameters.
            dtype: Dtype of the returned model's parameters.

        Returns:
            A ``LinearFormer`` with ``num_layers=1`` and ``nhead=1``.
        """
        return cls._hand_built(d_model, 0.0, batch_first, device, dtype)

    @classmethod
    def farthest_perturbed(cls, d_model, batch_first=True, device=None, dtype=None, scale=.01):
        """Return the ``farthest`` model with Gaussian noise added to every weight.

        Args:
            d_model: Size of each token representation.
            batch_first: Forwarded to the model constructor.
            device: Device of the returned model's parameters.
            dtype: Dtype of the returned model's parameters.
            scale: Standard deviation of the noise added to each entry.

        Returns:
            A ``LinearFormer`` with ``num_layers=1`` and ``nhead=1``.
        """
        return cls._hand_built(d_model, scale, batch_first, device, dtype)

class MergedLinearFormer(Module):
    """Single-head attention layer written with two merged square matrices.

    ``forward`` computes ``softmax(x @ QK @ x^T / sqrt(dim)) @ x @ VO`` for
    every batch element, where ``QK`` and ``VO`` are ``(d_model, d_model)``
    parameters.

    Args:
        d_model: Size of each token representation.
        batch_first: Accepted but not used; ``forward`` always treats its
            input as ``(batch, tokens, dim)``.
        device: Device on which the parameters are created.
        dtype: Dtype of the parameters.
    """

    def __init__(self, d_model, batch_first=True, device=None, dtype=None):
        super().__init__()
        self.QK = torch.nn.Parameter(torch.randn((d_model, d_model), device=device, dtype=dtype))
        self.VO = torch.nn.Parameter(torch.randn((d_model, d_model), device=device, dtype=dtype))

    def forward(self, x):
        """Apply merged attention to ``x`` of shape ``(batch, tokens, dim)``."""
        _, _, dim = x.shape
        # scores[b, t, u] = x[b, t] @ QK @ x[b, u] / sqrt(dim)
        scores = torch.einsum("btd,de,bue->btu", x, self.QK, x) / sqrt(dim)
        # attn_matrix is (batch, tokens, tokens), normalized over the last axis
        attn_matrix = torch.softmax(scores, dim=2)
        return torch.einsum("btu,bud,de->bte", attn_matrix, x, self.VO)

    @classmethod
    def extract_attn(cls, my_layer):
        """Build a merged model equivalent to ``my_layer``.

        ``my_layer`` must expose ``self_attn.in_proj_weight``,
        ``self_attn.out_proj.weight`` and ``linear.weight``. The merge is
        written for a single attention head without biases: ``QK`` is the
        query projection times the key projection, and ``VO`` is the value
        projection followed by the attention output projection and the
        trailing linear map.

        Args:
            my_layer: Layer whose weights are merged. It is not modified.

        Returns:
            A new ``MergedLinearFormer`` on the same device and with the same
            dtype as ``my_layer``'s attention weights.
        """
        attn = my_layer.self_attn
        in_proj = attn.in_proj_weight.detach()
        out_proj = attn.out_proj.weight.detach()
        linear = my_layer.linear.weight.detach()
        # Read the size from the layer passed in, not from a notebook global.
        dim = out_proj.shape[0]
        # in_proj_weight stacks the query, key and value projections row-wise.
        w_q = in_proj[:dim, :]
        w_k = in_proj[dim:(2 * dim), :]
        w_v = in_proj[(2 * dim):, :]
        merged = cls(dim, device=in_proj.device, dtype=in_proj.dtype)
        with torch.no_grad():
            merged.QK.copy_(w_q.T @ w_k)
            merged.VO.copy_(w_v.T @ out_proj.T @ linear.T)
        return merged


In [3]:
def gen_sentence(ntokens, dim, rng=None):
    """Sample a random "sentence" of unit-norm token vectors.

    Args:
        ntokens: Number of rows (tokens) to generate.
        dim: Length of each token vector.
        rng: Optional ``torch.Generator`` passed to ``torch.randn``.

    Returns:
        Tensor of shape ``(ntokens, dim)`` whose rows are standard-normal
        draws divided by their own Euclidean norm.
    """
    tokens = torch.randn((ntokens, dim), generator=rng)
    # One norm per row, kept as a column so the division broadcasts row-wise.
    row_norms = tokens.norm(dim=1, keepdim=True)
    return tokens / row_norms

def label_farthest(sentence):
    """Return, for each row of ``sentence``, the row that is farthest from it.

    Distances are computed as ``2 - 2 * <x_i, x_j>``, which is the squared
    Euclidean distance when the rows have unit norm.

    Args:
        sentence: 2-D tensor with one token per row.

    Returns:
        Tensor of the same shape as ``sentence`` whose ``i``-th row is the row
        of ``sentence`` with the largest distance to row ``i``.
    """
    doubled_gram = (2 * sentence) @ sentence.T
    pairwise = 2 - doubled_gram
    # Column j holds the distances from every token to token j.
    farthest_idx = torch.argmax(pairwise, dim=0)
    return sentence[farthest_idx]

In [4]:
class FarthestPointDataset(torch.utils.data.IterableDataset):
    """Endless stream of ``(sentence, label)`` pairs for the farthest-point task.

    Each sentence comes from ``gen_sentence`` and its label from
    ``label_farthest``.

    TODO: let ntokens vary according to some reasonable distribution

    Args:
        ntokens: Number of tokens in every generated sentence.
        dim: Length of each token vector.
        seed: Base seed for the generator; when None a fresh
            non-deterministic seed is drawn each time iteration starts.
    """

    def __init__(self, ntokens, dim, seed=None):
        super().__init__()
        self.ntokens = ntokens
        self.dim = dim
        self.seed = seed

    def __iter__(self):
        """Yield ``(sentence, label)`` pairs forever."""
        generator = torch.Generator()
        info = torch.utils.data.get_worker_info()
        # Outside a DataLoader worker process there is no worker info.
        offset = info.id if info is not None else 0
        if self.seed is None:
            base_seed = generator.seed()
        else:
            base_seed = self.seed
        # Offset by the worker id so each worker produces a different stream.
        generator.manual_seed(base_seed + offset)

        while True:
            tokens = gen_sentence(self.ntokens, self.dim, rng=generator)
            yield tokens, label_farthest(tokens)


In [5]:
class LitSequenceRegression(pl.LightningModule):
    """Lightning module that fits ``model`` to target sequences with an MSE loss.

    Args:
        model: Module mapping an input batch ``x`` to predictions with the
            same shape as the targets ``y``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        """Compute and log the mean-squared error for one ``(x, y)`` batch."""
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        y_hat = self.model(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Return Adam (lr=1e-3) with a StepLR schedule (x0.9 every 30 scheduler steps)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # `verbose` is deprecated in torch.optim.lr_scheduler (PyTorch >= 2.2);
        # leaving it out keeps the same silent behaviour without the warning.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)
        return [optimizer], [lr_scheduler]

In [125]:
# Experiment configuration: sentence shape, model size and batch size used to
# build the Lightning module and the training data loader below.
ntokens = 10
dim = 4
num_layers = 1
nhead = 1
batch_size = 64
# Default: a randomly initialized LinearFormer. The commented alternatives
# start from the hand-built solution (with or without noise) instead.
lit_model = LitSequenceRegression(LinearFormer(num_layers, dim, nhead))
# lit_model = LitSequenceRegression(LinearFormer.farthest_perturbed(dim))
# lit_model = LitSequenceRegression(LinearFormer.farthest(dim))
# No seed is passed to the dataset, so the generated stream differs from run to run.
data = torch.utils.data.DataLoader(FarthestPointDataset(ntokens, dim), batch_size=batch_size, num_workers=0)

In [7]:
# trainer = pl.Trainer(limit_train_batches=100, max_epochs=200)
# trainer.fit(model=lit_model, train_dataloaders=data)

In [208]:
# x = gen_sentence(ntokens, dim)
# y = label_farthest(x)
# yhat = lit_model.model(x)

In [8]:
# Smoke test: draw one small batch (2 sentences of 3 tokens, dim 4), run the
# current model on it and display the MSE against the labels.
# `x`, `y` and `yhat` are reused by later cells.
dl = torch.utils.data.DataLoader(FarthestPointDataset(ntokens=3, dim=4), batch_size=2, num_workers=0)
x, y = next(iter(dl))
yhat = lit_model.model(x)
torch.nn.functional.mse_loss(y, yhat)

tensor(0.2501, grad_fn=<MseLossBackward0>)

In [92]:
# print(x)
# print(y)
# print(yhat)

In [216]:
# list(lit_model.model.named_parameters())

In [148]:
# Grab the first (and, with num_layers = 1, only) layer of the model and list
# its parameters. `layer0` is reused by later cells.
layer0 = lit_model.model.layers[0]
# layer0 = LinearFormer.farthest(dim).layers[0]
list(layer0.named_parameters())

[('self_attn.in_proj_weight',
  Parameter containing:
  tensor([[-0.3483,  0.1091,  0.1221,  0.0469],
          [-0.1485, -0.4305,  0.5641,  0.3172],
          [-0.2736,  0.3633, -0.5100, -0.4738],
          [ 0.1057, -0.4365, -0.3454, -0.5276],
          [ 0.3811,  0.1482,  0.3414, -0.3970],
          [-0.2347,  0.5068,  0.1093, -0.1075],
          [ 0.5156, -0.2980,  0.1174, -0.5813],
          [-0.0958,  0.0225,  0.5258, -0.1198],
          [ 0.2671, -0.2377,  0.5364, -0.4530],
          [ 0.3074,  0.2341,  0.0942, -0.4606],
          [ 0.4801,  0.5533,  0.3059,  0.1502],
          [ 0.1991, -0.5952,  0.3031,  0.1864]], requires_grad=True)),
 ('self_attn.out_proj.weight',
  Parameter containing:
  tensor([[-0.3476, -0.0949, -0.1037,  0.0528],
          [-0.3409, -0.4293, -0.1437, -0.4244],
          [-0.4102,  0.4311,  0.2849,  0.3513],
          [ 0.0233,  0.2812,  0.0793,  0.1873]], requires_grad=True)),
 ('linear.weight',
  Parameter containing:
  tensor([[ 0.1073,  0.0315, -0.46

In [168]:
# Sanity check: the merged QK/VO model built from layer0 should give the same
# values as evaluating layer0(x) directly.
MergedLinearFormer.extract_attn(layer0)(x)

tensor([[[ 0.0518, -0.0260,  0.0318,  0.0421],
         [ 0.0749, -0.0587,  0.0354,  0.0845],
         [ 0.0768, -0.0613,  0.0359,  0.0879]],

        [[ 0.1126, -0.0826,  0.0511,  0.1229],
         [ 0.1121, -0.0826,  0.0499,  0.1222],
         [ 0.1089, -0.0753,  0.0505,  0.1145]]], grad_fn=<ViewBackward0>)

In [188]:
# Reference output of the original layer on the same batch `x`.
layer0(x)

tensor([[[ 0.0518, -0.0260,  0.0318,  0.0421],
         [ 0.0749, -0.0587,  0.0354,  0.0845],
         [ 0.0768, -0.0613,  0.0359,  0.0879]],

        [[ 0.1126, -0.0826,  0.0511,  0.1229],
         [ 0.1121, -0.0826,  0.0499,  0.1222],
         [ 0.1089, -0.0753,  0.0505,  0.1145]]], grad_fn=<UnsafeViewBackward0>)

In [225]:
# Load layer0's weights into the tutorial attention implementation and evaluate
# it on `x`.
# NOTE: `multihead_attn_eg` is imported in a cell that appears further down in
# the notebook; on a fresh kernel that import has to run before this cell.
# NOTE(review): the recorded output of this cell differs from the recorded
# layer0(x) output; cells were executed out of order, so re-check on a clean run.
tutorial = multihead_attn_eg.MultiHeadAttention(heads=1, d_model=4, dropout_prob=0, bias=False)
tutorial.query.linear.weight.data = layer0.self_attn.in_proj_weight.data[:dim, :].T
tutorial.key.linear.weight.data = layer0.self_attn.in_proj_weight.data[dim:(2*dim), :].T
tutorial.value.linear.weight.data = layer0.self_attn.in_proj_weight.data[(2*dim):, :].T
# The output projection absorbs layer0's trailing linear map.
tutorial.output.weight.data = layer0.self_attn.out_proj.weight.data.T @ layer0.linear.weight.T
tutorial.output.bias.data = torch.zeros(4)
# Axes 0 and 1 of x are swapped on the way in and swapped back on the way out;
# presumably the tutorial module expects (tokens, batch, dim) - TODO confirm.
tutorial(query=x.swapaxes(0, 1), key=x.swapaxes(0, 1), value=x.swapaxes(0, 1)).swapaxes(0,1)

tensor([[[ 0.0309, -0.1271, -0.0719, -0.1084],
         [ 0.0206, -0.1124, -0.0644, -0.0963],
         [ 0.0117, -0.0996, -0.0578, -0.0858]],

        [[-0.0520, -0.0656, -0.0393, -0.0598],
         [-0.0525, -0.0643, -0.0384, -0.0590],
         [-0.0524, -0.0667, -0.0395, -0.0608]]], grad_fn=<TransposeBackward0>)

In [170]:
# Demo in 2-D: print one batch, its farthest-point labels and the prediction of
# the hand-built LinearFormer.farthest model for comparison.
# Note that this rebinds the global `dl` used by an earlier cell.
dl = torch.utils.data.DataLoader(FarthestPointDataset(3, 2), batch_size=2, num_workers=0)
batch = next(iter(dl))
print("Input:")
print(batch[0])
print("Label:")
print(batch[1])
print("Prediction:")
print(LinearFormer.farthest(2)(batch[0]))

Input:
tensor([[[-0.7552, -0.6555],
         [ 0.6431, -0.7658],
         [-0.1499, -0.9887]],

        [[ 0.5727, -0.8197],
         [-0.4568,  0.8895],
         [-0.9479,  0.3185]]])
Label:
tensor([[[ 0.6431, -0.7658],
         [-0.7552, -0.6555],
         [ 0.6431, -0.7658]],

        [[-0.4568,  0.8895],
         [ 0.5727, -0.8197],
         [ 0.5727, -0.8197]]])
Prediction:
tensor([[[ 0.6431, -0.7658],
         [-0.7552, -0.6555],
         [ 0.6431, -0.7658]],

        [[-0.4568,  0.8895],
         [ 0.5727, -0.8197],
         [ 0.5727, -0.8197]]], grad_fn=<UnsafeViewBackward0>)


In [54]:
import multihead_attn_eg

In [85]:
# Rebuild the tutorial attention module and set its weights by hand.
tutorial = multihead_attn_eg.MultiHeadAttention(heads=1, d_model=4, dropout_prob=0, bias=False)
# Earlier experiment, kept for reference:
# tutorial.query.linear.weight.data = 10000 * torch.Tensor([
#     [0,0,0,1],
#     [0,0,0,0],
#     [0,0,0,0],
#     [0,0,0,0]
# ]).T
# tutorial.key.linear.weight.data = torch.eye(4)
# Query weight: 10000 times a matrix whose first column (after the transpose)
# is all ones and whose other entries are zero.
tutorial.query.linear.weight.data = 10000 * torch.Tensor([
    [1,1,1,1],
    [0,0,0,0],
    [0,0,0,0],
    [0,0,0,0]
]).T
# Key weight: ones in the last column (after the transpose), zeros elsewhere.
tutorial.key.linear.weight.data = torch.Tensor([
    [0,0,0,0],
    [0,0,0,0],
    [0,0,0,0],
    [1,1,1,1]
]).T
# Value and output projections are identities and the output bias is zeroed.
tutorial.value.linear.weight.data = torch.eye(4)
tutorial.output.weight.data = torch.eye(4)
tutorial.output.bias.data = torch.zeros(4)

In [88]:
# NOTE(review): `inn` and `self` are not defined in any visible cell; this looks
# like an expression left over from a debugging session and will raise a
# NameError on a fresh kernel - confirm and remove or rewrite.
inn[:, 0, :] @ self.query.linear.weight.data.T @ self.key.linear.weight.data @ inn[:, 0, :].T


In [87]:
# Evaluate the tutorial attention module on `x`, swapping axes 0 and 1 on the
# way in and back on the way out.
tutorial(query=x.swapaxes(0, 1), key=x.swapaxes(0, 1), value=x.swapaxes(0, 1)).swapaxes(0,1)

tensor([[[-0.7544,  0.3092, -0.5718,  0.0915],
         [-0.7544,  0.3092, -0.5718,  0.0915],
         [-0.4610, -0.3171,  0.7305,  0.3915]],

        [[-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.1032, -0.9727,  0.0969, -0.1838]]], grad_fn=<TransposeBackward0>)

In [91]:
# NOTE(review): `merged` is not assigned in any visible cell; it presumably
# holds a MergedLinearFormer from a cell that was since edited or removed -
# define it explicitly before this cell.
merged(x)

tensor([[[-0.7544,  0.3092, -0.5718,  0.0915],
         [-0.7544,  0.3092, -0.5718,  0.0915],
         [-0.4610, -0.3171,  0.7305,  0.3915]],

        [[-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.1032, -0.9727,  0.0969, -0.1838]]], grad_fn=<ViewBackward0>)

In [92]:
# Output of layer0's attention module alone (without the trailing linear map)
# on `x`; index 0 of the returned tuple is the attended values.
layer0.self_attn(x, x, x, need_weights=False)[0]

tensor([[[-0.7544,  0.3092, -0.5718,  0.0915],
         [-0.4610, -0.3171,  0.7305,  0.3915],
         [-0.4610, -0.3171,  0.7305,  0.3915]],

        [[-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.1032, -0.9727,  0.0969, -0.1838]]], grad_fn=<TransposeBackward0>)

In [71]:
# Display the batch `x` used in the comparisons above.
x

tensor([[[-0.7544,  0.3092, -0.5718,  0.0915],
         [-0.4610, -0.3171,  0.7305,  0.3915],
         [ 0.1082, -0.6548,  0.6668,  0.3390]],

        [[-0.1032, -0.9727,  0.0969, -0.1838],
         [-0.3567, -0.5130, -0.2228,  0.7483],
         [-0.9331,  0.1338,  0.2648,  0.2032]]])