In [1]:
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Linear, LayerNorm, MultiheadAttention
import lightning.pytorch as pl

In [162]:
# class MyMultiheadAttention(Module):
#     # start from here: https://nn.labml.ai/transformers/mha.html#section-5

class LinearformerLayer(Module):
    # TODO: extremely frustrating but they require the dimension of the token representations to be divisible by the number of heads
    def __init__(self, d_model, nhead,
                 batch_first=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, bias=False, batch_first=batch_first, **factory_kwargs)
        self.linear = Linear(d_model, d_model, bias=False, **factory_kwargs)

    def forward(self, x):
        # N.B.: This assumes that all sequences in the batch have the same length
        # if they are not, we will need to use `key_padding_mask`
        x = self.self_attn(x, x, x, need_weights=False)[0]
        x = self.linear(x)
        return x

class LinearFormer(Module):
    def __init__(self, num_layers, d_model, nhead, 
                 batch_first=True, device=None, dtype=None):
        super().__init__()
        self.layers = torch.nn.Sequential(*(LinearformerLayer(
            d_model, nhead, batch_first=batch_first, device=device, dtype=dtype) for _ in range(num_layers)))
        # self.layers = ModuleList([LinearformerLayer(
        #     d_model, nhead, batch_first=batch_first, device=device, dtype=dtype) for _ in range(num_layers)])

    def forward(self, x):
        return self.layers(x)

    @classmethod
    def farthest(cls, d_model, batch_first=True, device=None, dtype=None):
        model = cls(1, d_model, 1, batch_first=batch_first, device=device, dtype=dtype)
        model.layers[0].linear.weight.data = torch.eye(*model.layers[0].linear.weight.data.shape)
        if model.layers[0].linear.bias is not None:
            model.layers[0].linear.bias.data = torch.zeros_like(model.layers[0].linear.bias)
        model.layers[0].self_attn.in_proj_weight.data = torch.concat([torch.eye(d_model), -1000 * torch.eye(d_model), torch.eye(d_model)])
        if model.layers[0].self_attn.in_proj_bias is not None:
            model.layers[0].self_attn.in_proj_bias.data = torch.zeros_like(model.layers[0].self_attn.in_proj_bias)
        model.layers[0].self_attn.out_proj.weight.data = torch.eye(*model.layers[0].self_attn.out_proj.weight.shape)
        if model.layers[0].self_attn.out_proj.bias:
            model.layers[0].self_attn.out_proj.bias.data = torch.zeros_like(model.layers[0].self_attn.out_proj.bias)
        return model

    @classmethod
    def farthest_perturbed(cls, d_model, batch_first=True, device=None, dtype=None):
        scale = .01
        model = cls(1, d_model, 1, batch_first=batch_first, device=device, dtype=dtype)
        model.layers[0].linear.weight.data = torch.eye(*model.layers[0].linear.weight.data.shape) + scale * torch.randn_like(model.layers[0].linear.weight.data)
        if model.layers[0].linear.bias is not None:
            model.layers[0].linear.bias.data = torch.zeros_like(model.layers[0].linear.bias) + scale * torch.randn_like(model.layers[0].linear.bias)
        model.layers[0].self_attn.in_proj_weight.data = torch.concat([torch.eye(d_model), -1000 * torch.eye(d_model), torch.eye(d_model)]) + scale * torch.randn_like(model.layers[0].self_attn.in_proj_weight.data)
        if model.layers[0].self_attn.in_proj_bias is not None:
            model.layers[0].self_attn.in_proj_bias.data = torch.zeros_like(model.layers[0].self_attn.in_proj_bias) + scale * torch.randn_like(model.layers[0].self_attn.in_proj_bias.data)
        model.layers[0].self_attn.out_proj.weight.data = torch.eye(*model.layers[0].self_attn.out_proj.weight.shape) + scale * torch.randn_like(model.layers[0].self_attn.out_proj.weight.data)
        if model.layers[0].self_attn.out_proj.bias:
            model.layers[0].self_attn.out_proj.bias.data = torch.zeros_like(model.layers[0].self_attn.out_proj.bias) + scale * torch.randn_like(model.layers[0].self_attn.out_proj.bias.data)
        return model

In [153]:
def gen_sentence(ntokens, dim, rng=None):
    x = torch.randn((ntokens, dim), generator=rng)
    # normalize each row
    return x / x.norm(dim=1).reshape(-1, 1)

def label_farthest(sentence):
    distances = 2 - 2 * sentence @ sentence.T
    farthests = distances.argmax(dim=0)
    return sentence[farthests, :]

In [154]:
class FarthestPointDataset(torch.utils.data.IterableDataset):
    # TODO: let ntokens vary according to some reasonable distribution
    def __init__(self, ntokens, dim, seed=None):
        super().__init__()
        self.ntokens = ntokens
        self.dim = dim
        self.seed = seed

    def __iter__(self):
        rng = torch.Generator()
        worker_info = torch.utils.data.get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        seed = rng.seed() if self.seed is None else self.seed
        rng.manual_seed(seed + worker_id)

        while True:
            sentence = gen_sentence(self.ntokens, self.dim, rng=rng)
            yield sentence, label_farthest(sentence)



In [155]:
class LitSequenceRegression(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        y_hat = self.model(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # return optimizer
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9, verbose=False)
        return [optimizer], [lr_scheduler]

In [165]:
ntokens = 10
dim = 4
num_layers = 1
nhead = 1
batch_size = 64
# lit_model = LitSequenceRegression(LinearFormer(num_layers, dim, nhead))
lit_model = LitSequenceRegression(LinearFormer.farthest_perturbed(dim))
data = torch.utils.data.DataLoader(FarthestPointDataset(ntokens, dim), batch_size=batch_size, num_workers=0)

In [166]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=200)
trainer.fit(model=lit_model, train_dataloaders=data)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type         | Params
---------------------------------------
0 | model | LinearFormer | 80    
---------------------------------------
80        Trainable params
0         Non-trainable params
80        Total params
0.000     Total estimated model params size (MB)


Epoch 32:  93%|█████████▎| 93/100 [00:00<00:00, 111.39it/s, v_num=21, train_loss=0.000607] 

In [120]:
x = gen_sentence(ntokens, dim)
y = label_farthest(x)
yhat = lit_model.model(x)

In [144]:
dl = torch.utils.data.DataLoader(FarthestPointDataset(3, dim), batch_size=1000, num_workers=0)
x, y = next(iter(dl))
yhat = lit_model.model(x)
torch.nn.functional.mse_loss(y, yhat)

tensor(0.0042, grad_fn=<MseLossBackward0>)

tensor(0.0020, grad_fn=<MseLossBackward0>)

In [92]:
# print(x)
# print(y)
# print(yhat)

In [83]:
layer0 = lit_model.model.layers[0]
# layer0 = LinearFormer.farthest(dim).layers[0]
list(layer0.named_parameters())

[('self_attn.in_proj_weight',
  Parameter containing:
  tensor([[-2.1683, -3.5842,  3.3000,  4.0539],
          [ 1.3107, -3.1545,  3.1496, -4.6916],
          [-4.1139, -3.4144, -3.9734, -1.6968],
          [-4.5240,  3.5688,  2.6734, -1.8239],
          [ 2.1421,  3.3045, -3.3916, -3.9073],
          [-1.3476,  3.0847, -3.2302,  4.6493],
          [ 3.8582,  2.9305,  3.7388,  1.6091],
          [ 4.6814, -3.3468, -2.6973,  1.7620],
          [-1.0091, -0.3342, -0.1035, -0.1046],
          [ 0.0347,  0.8540,  0.6670, -0.7823],
          [ 0.2932, -0.8154,  0.4150, -0.5697],
          [-0.1577,  0.1001,  1.0149,  0.3995]], requires_grad=True)),
 ('self_attn.out_proj.weight',
  Parameter containing:
  tensor([[-0.1892, -0.2359, -0.6602,  0.9116],
          [ 0.8456,  0.3112, -0.2822,  0.0273],
          [-0.2101,  0.8255,  0.0537,  0.3716],
          [ 0.3103, -0.1808,  0.6831,  0.4201]], requires_grad=True)),
 ('linear.weight',
  Parameter containing:
  tensor([[-0.0961, -1.0117,  0.17

In [169]:
layer0.self_attn.in_proj_weight[-dim:,:]

tensor([[-1.0091, -0.3342, -0.1035, -0.1046],
        [ 0.0347,  0.8540,  0.6670, -0.7823],
        [ 0.2932, -0.8154,  0.4150, -0.5697],
        [-0.1577,  0.1001,  1.0149,  0.3995]], grad_fn=<SliceBackward0>)

In [167]:
layer0.self_attn.in_proj_weight[-dim:,:] @ layer0.self_attn.out_proj.weight @ layer0.linear.weight

tensor([[-0.5772,  0.2941,  0.7836,  0.4064],
        [-0.1744, -0.0619,  0.0508, -1.3302],
        [ 0.1342,  1.0381, -0.3636, -0.1514],
        [ 0.5791,  0.0262,  0.3389, -0.3902]], grad_fn=<MmBackward0>)

In [85]:
layer0.self_attn.in_proj_weight[:dim,:] @ layer0.self_attn.in_proj_weight[dim:(2*dim),:]

tensor([[ 31.8949, -22.1183,  20.3345,   4.2615],
        [ -2.7532,  19.5320,  30.1750, -22.9863],
        [-27.4843, -30.0926,  14.7030,  -9.1836],
        [-12.7243,   9.9977,  18.7303,  35.3572]], grad_fn=<MmBackward0>)

In [76]:
layer0.self_attn.in_proj_weight[4:8,:] @ layer0.self_attn.in_proj_weight[:4,:]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2 and 4x2)

In [65]:
dl = torch.utils.data.DataLoader(FarthestPointDataset(3, 2), batch_size=2, num_workers=0)
batch = next(iter(dl))
print("Input:")
print(batch[0])
print("Label:")
print(batch[1])
print("Prediction:")
print(LinearFormer.farthest(2)(batch[0]))

Input:
tensor([[[-0.5255, -0.8508],
         [ 0.7689, -0.6394],
         [ 0.1725,  0.9850]],

        [[ 0.1269, -0.9919],
         [ 0.0750,  0.9972],
         [-0.8024, -0.5968]]])
Label:
tensor([[[ 0.1725,  0.9850],
         [ 0.1725,  0.9850],
         [-0.5255, -0.8508]],

        [[ 0.0750,  0.9972],
         [ 0.1269, -0.9919],
         [ 0.0750,  0.9972]]])
Prediction:
tensor([[[ 0.1725,  0.9850],
         [ 0.1725,  0.9850],
         [-0.5255, -0.8508]],

        [[ 0.0750,  0.9972],
         [ 0.1269, -0.9919],
         [ 0.0750,  0.9972]]], grad_fn=<UnsafeViewBackward0>)
