# Imports

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%load_ext watermark

In [2]:
import numpy as np
import torch
from torch import nn

In [3]:
%watermark -i -iv

numpy: 1.21.2
torch: 1.10.0



# Introduction


## Setup
From ["Lightweight and Efficient End-to-End Speech Recognition Using Low-Rank Transformer"](https://arxiv.org/abs/1910.13923):

"The design is based on matrix factorization by approximating the matrix $\mathbf{W} \in \mathbb{R}^{m \times n}$ in the linear feed-forward unit using two smaller matrices, $\mathbf{E} \in \mathbb{R}^{m \times r}$ and $\mathbf{D} \in \mathbb{R}^{r \times n}$:

$\mathbf{W} \approx \mathbf{E} \times \mathbf{D}$

The matrix $\mathbf{W}$ requires $m n$ parameters and $m n$ flops, while $\mathbf{E}$ and $\mathbf{D}$ require $r m+r n=r(m+n)$ parameters and $r(m+n)$ flops.

If we take the rank to be very low $r \ll m, n$, the number of parameters and flops in $\mathbf{E}$ and $\mathbf{D}$ are much smaller compared to $\mathbf{W}$."
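
To make the quoted count concrete, here is a minimal sketch that plugs in the shapes of the first layer from the test section of this notebook ($m=768$, $n=256$, $r=64$). The helper names are hypothetical and biases are ignored, as in the quote. It also shows the break-even rank $r = \frac{mn}{m+n}$, above which the factorization no longer saves parameters.

```python
def dense_params(m, n):
    """Number of weights in the full m x n matrix W (bias ignored)."""
    return m * n


def low_rank_params(m, n, r):
    """Number of weights in E (m x r) plus D (r x n) (bias ignored)."""
    return r * (m + n)


def break_even_rank(m, n):
    """The factorization only saves parameters when r < m*n / (m + n)."""
    return (m * n) / (m + n)


m, n, r = 768, 256, 64
print(dense_params(m, n))        # 196608
print(low_rank_params(m, n, r))  # 65536
print(break_even_rank(m, n))     # 192.0
```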

## Applications
This setup is interesting because it could reduce the memory footprint of Transformer setups that use short sequence lengths, as outlined in ["Greenformers - Improving Computation and Memory Efficiency in Transformer Models via Low-Rank Approximation"](https://arxiv.org/abs/2108.10808):

"The Low-Rank Transformer model is suitable for improving both the time and memory efficiency in processing short-sequence (≤ 512) input data, while the Linformer model is suitable for improving the efficiency in processing long-sequence input data (> 512)."

**$\rightarrow$ This setup can be very interesting for Transformer setups that usually have sequence lengths below 768, e.g., ViT!**

In addition, it does not suffer from the two deficiencies of the Linformer setup (see https://github.com/lucidrains/linformer#linformer-for-pytorch) as it works for the auto-regressive case and assumes no fixed sequence length.
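
The reason for this can be illustrated with a small sketch (an assumption-laden toy, not taken from either paper): a factorized projection acts on each position independently, so it accepts any sequence length and the output at one position never depends on later positions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Factorized replacement for nn.Linear(512, 512) with rank 64.
proj = nn.Sequential(nn.Linear(512, 64), nn.Linear(64, 512))

long_seq = torch.randn(1, 300, 512)   # (batch, sequence, features)
short_seq = long_seq[:, :16]          # same data, truncated to 16 positions

out_long = proj(long_seq)
out_short = proj(short_seq)

# Both sequence lengths work without any change to the module.
print(out_long.shape, out_short.shape)

# The first 16 outputs are unaffected by the positions that follow them.
print(torch.allclose(out_long[:, :16], out_short, atol=1e-5))
```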

## Sources
The outlined "linear factorization" setup is inspired by:
1. https://discuss.pytorch.org/t/factorization-of-a-weight-matrix-as-products-of-low-rank-matrices/76278
1. ["Lightweight and Efficient End-to-End Speech Recognition Using Low-Rank Transformer"](https://arxiv.org/abs/1910.13923)
1. ["Greenformers - Improving Computation and Memory Efficiency in Transformer Models via Low-Rank Approximation"](https://arxiv.org/abs/2108.10808)
1. https://github.com/lucidrains/linformer

# Code

Code setup inspired by https://discuss.pytorch.org/t/how-to-replace-all-relu-activations-in-a-pretrained-network/31591/7.

In [4]:
def nn_linear_factorization(model, rank=64):
    """
    Recursively replace nn.Linear(in_features=a, out_features=b) with
    nn.Sequential(nn.Linear(a, rank),
                  nn.Linear(rank, b))
    to replace the big linear layer with a factorization consisting of
    a sequence of two smaller linear layers.
    
    The lower the rank hyperparameter is, the higher the memory savings can be.
    Note: the new layers are freshly initialized and always use a bias.
    """
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            # Swap the dense layer for two smaller layers of the given rank.
            setattr(model, name, nn.Sequential(nn.Linear(child.in_features, rank),
                                               nn.Linear(rank, child.out_features)))
        else:
            # Pass the rank down so nested modules do not fall back to the default.
            nn_linear_factorization(child, rank)
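
The function above always creates layers with a bias, regardless of the original layer's setting (in the Transformer printout further down, `to_q` goes from `bias=False` to `bias=True`). A possible variant is sketched below; the name `factorize_linear_layers` is hypothetical. It assumes that only the second factor should carry the original bias setting and that layers should be skipped when `rank * (in + out)` would not be smaller than `in * out`.

```python
import torch
from torch import nn


def factorize_linear_layers(model, rank=64):
    """Hypothetical variant: keep the bias setting, skip unprofitable layers."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            m, n = child.in_features, child.out_features
            # Skip layers where two factors would not be smaller than W.
            if rank * (m + n) >= m * n:
                continue
            setattr(model, name, nn.Sequential(
                nn.Linear(m, rank, bias=False),                   # E: m -> rank
                nn.Linear(rank, n, bias=child.bias is not None),  # D: rank -> n
            ))
        else:
            factorize_linear_layers(child, rank)
    return model


model = nn.Sequential(nn.Linear(768, 256), nn.Linear(256, 128, bias=False))
factorize_linear_layers(model, rank=64)
out = model(torch.randn(2, 768))
print(out.shape)  # torch.Size([2, 128])
```

Like the original function, this discards the trained weights, so the model has to be trained or fine-tuned afterwards.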

# Test

In [5]:
def get_parameter_count(model):
    # Total number of elements across all parameter tensors.
    return sum(p.numel() for p in model.parameters())

## Basic

In [6]:
def get_simple_model():
    return nn.Sequential(nn.Linear(768, 256), nn.Linear(256, 128))

In [7]:
model = get_simple_model()

In [8]:
model

Sequential(
  (0): Linear(in_features=768, out_features=256, bias=True)
  (1): Linear(in_features=256, out_features=128, bias=True)
)

In [9]:
params_orig = get_parameter_count(model); params_orig

229760

In [10]:
nn_linear_factorization(model, 64)

In [11]:
model

Sequential(
  (0): Sequential(
    (0): Linear(in_features=768, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=256, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=128, bias=True)
  )
)

In [12]:
params_factorized = get_parameter_count(model); params_factorized

90624

In [13]:
params_orig, params_factorized, round(params_factorized/params_orig, 2)

(229760, 90624, 0.39)
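
These counts can be cross-checked by hand. The sketch below (helper names are hypothetical) recomputes them from the layer shapes alone, including the bias terms that the paper's $r(m+n)$ formula leaves out.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of one nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


def factorized_params(in_features, out_features, rank):
    """Two biased layers: in -> rank and rank -> out."""
    return linear_params(in_features, rank) + linear_params(rank, out_features)


layers = [(768, 256), (256, 128)]
orig = sum(linear_params(i, o) for i, o in layers)
fact = sum(factorized_params(i, o, 64) for i, o in layers)
print(orig, fact, round(fact / orig, 2))  # 229760 90624 0.39
```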

## Transformer

Based on: https://github.com/lucidrains/x-transformers#usage

In [14]:
from x_transformers import XTransformer

In [15]:
model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True      # tie embeddings of encoder and decoder
)

In [16]:
model

XTransformer(
  (encoder): TransformerWrapper(
    (token_emb): Embedding(256, 512)
    (pos_emb): AbsolutePositionalEmbedding(
      (emb): Embedding(1024, 512)
    )
    (emb_dropout): Dropout(p=0, inplace=False)
    (project_emb): Identity()
    (attn_layers): Encoder(
      (layers): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1): Attention(
            (to_q): Linear(in_features=512, out_features=512, bias=False)
            (to_k): Linear(in_features=512, out_features=512, bias=False)
            (to_v): Linear(in_features=512, out_features=512, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
            (to_out): Linear(in_features=512, out_features=512, bias=True)
          )
          (2): Residual()
        )
        (1): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1): FeedForward(
            (net): Sequential(
              (0): Sequentia

In [17]:
params_orig = get_parameter_count(model); params_orig

45555200

In [18]:
src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))
tgt_mask = torch.ones_like(tgt).bool()

In [19]:
%timeit -n 3 loss = model(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask) # (1, 1024, 512)

4.92 s ± 1.25 s per loop (mean ± std. dev. of 7 runs, 3 loops each)


In [21]:
nn_linear_factorization(model, 64)

In [22]:
model

XTransformer(
  (encoder): TransformerWrapper(
    (token_emb): Embedding(256, 512)
    (pos_emb): AbsolutePositionalEmbedding(
      (emb): Embedding(1024, 512)
    )
    (emb_dropout): Dropout(p=0, inplace=False)
    (project_emb): Identity()
    (attn_layers): Encoder(
      (layers): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1): Attention(
            (to_q): Sequential(
              (0): Linear(in_features=512, out_features=64, bias=True)
              (1): Linear(in_features=64, out_features=512, bias=True)
            )
            (to_k): Sequential(
              (0): Linear(in_features=512, out_features=64, bias=True)
              (1): Linear(in_features=64, out_features=512, bias=True)
            )
            (to_v): Sequential(
              (0): Linear(in_features=512, out_features=64, bias=True)
              (1): Linear(in_features=64, out_features=512, bias=True)
            )
            (dr

In [23]:
params_factorized = get_parameter_count(model); params_factorized

10035840

In [24]:
params_orig, params_factorized, round(params_factorized/params_orig, 2)

(45555200, 10035840, 0.22)

In [25]:
%timeit -n 3 loss = model(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask) # (1, 1024, 512)

1.89 s ± 194 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)


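**$\rightarrow$ In this simple setup, the factorization removes ~78% of the parameters (45.6M down to 10.0M) and cuts the CPU forward pass from ~4.9 s to ~1.9 s, i.e., ~60% less time. Note the large spread of the baseline timing (± 1.25 s).**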

# Outlook

1. Test the setup for standard Transformer setups!
1. Combine this with reversible layers to decrease memory even more?
1. Can the `rank` hyperparameter be learned, similar to ["Adaptive Attention Span in Transformers"](https://arxiv.org/abs/1905.07799) (code: https://github.com/facebookresearch/adaptive-span/blob/main/adaptive_span.py)?
1. Implement post-training low rank transformation with https://geotorch.readthedocs.io/en/latest/lowrank/lowrank.html?
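
As a starting point for the last item, a post-training factorization could be sketched with a plain truncated SVD instead of geotorch. The helper name `low_rank_from_linear` is hypothetical and this is only one possible approach: it initializes the two smaller layers so that their product is the best rank-`rank` approximation of the trained weight.

```python
import torch
from torch import nn


def low_rank_from_linear(linear, rank):
    """Hypothetical helper: factorize a trained nn.Linear via truncated SVD."""
    W = linear.weight.data  # shape (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    sqrt_s = S[:rank].sqrt()  # split the singular values between both factors

    first = nn.Linear(linear.in_features, rank, bias=False)
    second = nn.Linear(rank, linear.out_features, bias=linear.bias is not None)
    with torch.no_grad():
        first.weight.copy_(sqrt_s[:, None] * Vh[:rank])      # (rank, in)
        second.weight.copy_(U[:, :rank] * sqrt_s[None, :])   # (out, rank)
        if linear.bias is not None:
            second.bias.copy_(linear.bias)
    return nn.Sequential(first, second)


torch.manual_seed(0)
linear = nn.Linear(32, 16)
x = torch.randn(4, 32)

for rank in (4, 8, 16):
    approx = low_rank_from_linear(linear, rank)
    err = (linear(x) - approx(x)).abs().max().item()
    print(rank, err)  # the error shrinks as the rank grows
```

At full rank (16 here) the factorized layers reproduce the original output up to floating-point error; at lower ranks some fine-tuning would likely be needed to recover accuracy.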