In [1]:
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Linear, LayerNorm, MultiheadAttention
import lightning.pytorch as pl

In [17]:
# Possible follow-up: implement a custom `MyMultiheadAttention(Module)` instead of relying on
# torch.nn.MultiheadAttention. Starting point: https://nn.labml.ai/transformers/mha.html#section-5

class LinearformerLayer(Module):
    """A single self-attention step followed by a dense ``d_model -> d_model`` projection.

    No residual connection, normalization or non-linearity is applied; the layer is
    simply ``linear(self_attn(x, x, x))``.

    Args:
        d_model: size of each token representation.
        nhead: number of attention heads.
        batch_first: forwarded to ``MultiheadAttention``.
        device, dtype: factory arguments forwarded to both sub-modules.
    """

    # TODO: torch.nn.MultiheadAttention requires d_model to be divisible by nhead,
    # which constrains the (d_model, nhead) combinations we can try.
    def __init__(self, d_model, nhead,
                 batch_first=True, device=None, dtype=None):
        super().__init__()
        # Attention is constructed before the projection (parameter init order matters for RNG).
        self.self_attn = MultiheadAttention(
            d_model, nhead, batch_first=batch_first, device=device, dtype=dtype)
        self.linear = Linear(d_model, d_model, device=device, dtype=dtype)

    def forward(self, x):
        """Apply self-attention (query = key = value = ``x``) and then the linear projection.

        N.B.: no ``key_padding_mask`` is passed, so every sequence in the batch is assumed
        to have the same length; variable-length batches would need a padding mask.
        """
        attended, _ = self.self_attn(x, x, x, need_weights=False)
        return self.linear(attended)

class LinearFormer(Module):
    """A stack of ``num_layers`` LinearformerLayer blocks applied one after another.

    Args:
        num_layers: number of stacked layers.
        d_model: token representation size; must be divisible by ``nhead``
            (a torch.nn.MultiheadAttention requirement).
        nhead: number of attention heads in every layer.
        batch_first, device, dtype: forwarded to every layer.
    """

    def __init__(self, num_layers, d_model, nhead,
                 batch_first=True, device=None, dtype=None):
        super().__init__()
        self.layers = torch.nn.Sequential(*(
            LinearformerLayer(d_model, nhead, batch_first=batch_first, device=device, dtype=dtype)
            for _ in range(num_layers)))

    def forward(self, x):
        return self.layers(x)

    @classmethod
    def farthest(cls, d_model, batch_first=True, device=None, dtype=None, sharpness=1000.0):
        """Build a 1-layer, 1-head model whose weights are set by hand to solve the farthest-point task.

        Construction (single head):
          * query = x, key = -sharpness * x, value = x  (via ``in_proj_weight``, zero bias);
          * ``out_proj`` and the trailing ``linear`` are identities with zero bias.
        The attention logit of token i on token j is therefore proportional to
        ``-sharpness * <x_i, x_j>``, so the softmax concentrates on the token with the
        smallest dot product -- for unit-norm tokens, the farthest one -- and the output
        approximates that token. Larger ``sharpness`` makes the selection closer to a hard argmax.

        Args:
            d_model: token dimensionality (any positive size).
            batch_first, device, dtype: forwarded to the constructor; the hand-set
                weights are created with the same device/dtype.
            sharpness: magnitude of the negative key scaling (default 1000.0, as before).

        Returns:
            A ``LinearFormer`` instance with the weights described above.
        """
        model = cls(1, d_model, 1, batch_first=batch_first, device=device, dtype=dtype)
        layer = model.layers[0]
        attn = layer.self_attn
        eye = torch.eye(d_model, device=device, dtype=dtype)
        # copy_ keeps each parameter's shape/device/dtype (and errors on a shape mismatch),
        # unlike rebinding `.data`, which would silently replace them.
        with torch.no_grad():
            layer.linear.weight.copy_(eye)
            layer.linear.bias.zero_()
            # in_proj_weight stacks the (query, key, value) projections along dim 0.
            attn.in_proj_weight.copy_(torch.cat([eye, -sharpness * eye, eye]))
            attn.in_proj_bias.zero_()
            attn.out_proj.weight.copy_(eye)
            attn.out_proj.bias.zero_()
        return model


In [3]:
def gen_sentence(ntokens, dim, rng=None):
    """Sample a "sentence" of ``ntokens`` random unit vectors in ``dim`` dimensions.

    Entries are drawn i.i.d. from a standard normal (using the optional ``torch.Generator``
    ``rng``) and every row is then divided by its Euclidean norm.

    Returns:
        Tensor of shape ``(ntokens, dim)`` whose rows have unit norm.
    """
    raw = torch.randn((ntokens, dim), generator=rng)
    row_norms = raw.norm(dim=1, keepdim=True)  # shape (ntokens, 1), broadcasts over columns
    return raw / row_norms

def label_farthest(sentence):
    """For every token, return the token of the same sentence that lies farthest from it.

    Args:
        sentence: tensor of shape ``(..., ntokens, dim)``; any leading batch dimensions
            are handled independently. Rows are assumed to be unit-norm (as produced by
            ``gen_sentence``): the squared distance is computed as ``2 - 2 <a, b>``, which
            equals ``||a - b||^2`` only for unit vectors.

    Returns:
        Tensor with the same shape as ``sentence`` whose row ``i`` is the row farthest
        from row ``i``. Ties resolve to the lowest index (``argmax`` semantics).
    """
    distances = 2 - 2 * sentence @ sentence.transpose(-2, -1)
    # The distance matrix is symmetric, so reducing over the row axis yields, for each
    # token j, the index of the token farthest from j.
    farthests = distances.argmax(dim=-2)
    index = farthests.unsqueeze(-1).expand_as(sentence)
    return torch.gather(sentence, -2, index)

In [4]:
class FarthestPointDataset(torch.utils.data.IterableDataset):
    """Infinite stream of ``(sentence, labels)`` pairs for the farthest-point task.

    Each item is ``(x, label_farthest(x))`` with ``x = gen_sentence(ntokens, dim)``.

    Seeding:
        * ``seed=None``: every iterator draws a fresh non-deterministic seed, so
          DataLoader workers (and successive iterations) see different streams.
        * ``seed=<int>``: worker ``k`` uses ``seed + k``. Each call to ``iter()`` reseeds,
          so if the DataLoader re-creates its iterator every epoch, a fixed seed replays
          the same samples each epoch.
    """
    # TODO: let ntokens vary according to some reasonable distribution
    def __init__(self, ntokens, dim, seed=None):
        super().__init__()
        self.ntokens = ntokens
        self.dim = dim
        self.seed = seed

    def __iter__(self):
        rng = torch.Generator()
        if self.seed is None:
            # Generator.seed() already pulls non-deterministic entropy per worker; offsetting
            # that 64-bit value by the worker id is unnecessary and could overflow
            # manual_seed's accepted range.
            rng.seed()
        else:
            worker_info = torch.utils.data.get_worker_info()
            worker_id = 0 if worker_info is None else worker_info.id
            rng.manual_seed(self.seed + worker_id)

        while True:
            sentence = gen_sentence(self.ntokens, self.dim, rng=rng)
            yield sentence, label_farthest(sentence)



In [5]:
class LitSequenceRegression(pl.LightningModule):
    """Lightning wrapper that trains ``model`` to regress target sequences with an MSE loss.

    Args:
        model: module mapping a batch ``x`` to predictions comparable (via ``mse_loss``)
            with the targets ``y``.
        lr: Adam learning rate (default 1e-3).
        step_size: StepLR period, counted in scheduler steps (default 15).
        gamma: StepLR multiplicative decay (default 0.95).
    """

    def __init__(self, model, lr=1e-3, step_size=15, gamma=0.95):
        super().__init__()
        self.model = model
        self.lr = lr
        self.lr_step_size = step_size
        self.lr_gamma = gamma

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; it does not go through self.forward.
        x, y = batch
        y_hat = self.model(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        # Logged once per epoch (averaged over steps), to the progress bar and the logger.
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # `verbose` is deprecated on LR schedulers in recent PyTorch; False was the default anyway.
        # Schedulers returned this way are stepped once per epoch by Lightning by default.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.lr_step_size, gamma=self.lr_gamma)
        return [optimizer], [lr_scheduler]

In [6]:
# Problem size and model / data-loading configuration.
ntokens = 5       # tokens per sentence
dim = 4           # token dimensionality (d_model); MultiheadAttention needs dim % nhead == 0
num_layers = 1
nhead = 1
batch_size = 64

lit_model = LitSequenceRegression(
    model=LinearFormer(num_layers=num_layers, d_model=dim, nhead=nhead),
)
data = torch.utils.data.DataLoader(
    FarthestPointDataset(ntokens=ntokens, dim=dim),
    batch_size=batch_size,
    num_workers=0,
)

In [None]:
# The dataset is an endless stream, so each "epoch" is capped at a fixed number of batches.
trainer = pl.Trainer(
    max_epochs=100,
    limit_train_batches=100,
)
trainer.fit(model=lit_model, train_dataloaders=data)

In [None]:
# Spot-check the trained model on a fresh, unbatched sentence of shape (ntokens, dim).
x = gen_sentence(ntokens, dim)
y = label_farthest(x)
with torch.no_grad():  # pure inference: no need to build an autograd graph
    yhat = lit_model.model(x)

In [None]:
# Inputs, targets (farthest points) and model predictions, one tensor per block.
print(x, y, yhat, sep="\n")

In [21]:
# Sanity check: the hand-built `LinearFormer.farthest` model applied to the inputs
# (batch[0]) should approximately reproduce the labels (batch[1]) on a tiny 2-D example.
dl = torch.utils.data.DataLoader(
    FarthestPointDataset(ntokens=3, dim=2), batch_size=2, num_workers=0)
batch = next(iter(dl))
print(batch)
print(LinearFormer.farthest(d_model=2)(batch[0]))

[tensor([[[-0.5491, -0.8358],
         [-0.9887, -0.1501],
         [-0.4044, -0.9146]],

        [[ 0.7892,  0.6141],
         [-0.2937,  0.9559],
         [-0.9933,  0.1159]]]), tensor([[[-0.9887, -0.1501],
         [-0.4044, -0.9146],
         [-0.9887, -0.1501]],

        [[-0.9933,  0.1159],
         [ 0.7892,  0.6141],
         [ 0.7892,  0.6141]]])]
tensor([[[-0.9887, -0.1501],
         [-0.4044, -0.9146],
         [-0.9887, -0.1501]],

        [[-0.9933,  0.1159],
         [ 0.7892,  0.6141],
         [ 0.7892,  0.6141]]], grad_fn=<AddBackward0>)
