# Analysis of the memory consumption of the Transformer given the maximum sequence length of 1024

This notebook determines how large the batch size can be when all other hyperparameters that influence
memory consumption are held fixed.

Most default values are taken from "Attention Is All You Need" (AIAYN), except for a lower value of d_model.
See the report for more details.

~ 6.2 GiB GPU Memory

## Load Transformer model and initialize Hyperparameters with little to no impact on the memory consumption

In [10]:
from src.models.transformer import make_time_series_model, make_encoder_to_decoder_time_series_model, make_encoder_only_time_series_model
import numpy as np
import torch
import torch.nn as nn

# Hyperparameters with no significant influence on the memory consumption
epochs = 3
# Model Parameters
dropout = 0.2  # Dropout rate
debug = False  # If True, run only 1 batch per epoch
d_input = 1 # From dataset
d_output = 1  # From dataset

# Loss function for gradient descent, possible choices are SL1, MSE
loss_function_ = "SL1"
device = torch.device("cuda:0")

## Define some helper functions to try and run the model with increasingly higher sizes

In [11]:
def test_model_memory_consumption(model, d_model, N, h, d_ff, sequence_len, batch_size):
    """Test whether the Transformer model fits into the memory of the current system.

    :param model: (str) one of {'tst': normal Transformer, 'encoder_only_tst': encoder-only architecture, 'enc2dec_tst': encoder-to-decoder Transformer}
    :param d_model: (int) model dimensionality
    :param N: (int) number of layers
    :param h: (int) number of heads in a single layer
    :param d_ff: (int) dimensionality of the position-wise feed-forward network
    :param sequence_len: (int) length of the input and output sequence
    :param batch_size: (int) batch size to be tested
    :return: None; PyTorch raises a RuntimeError (e.g. CUDA out of memory) if the run does not fit
    """
    global d_input, d_output, dropout
    train_data = torch.rand((batch_size, sequence_len, d_input))
    train_data_out = torch.zeros_like(train_data)
    y_in = torch.zeros_like(train_data)
    x_mask = torch.ones(1, 1, sequence_len).to(device)
    y_mask = torch.ones_like(x_mask)
    loss_function = nn.SmoothL1Loss()

    if model == "tst":
        model = make_time_series_model(d_input=d_input,
                                   d_output=d_output,
                                   N=N,
                                   d_model=d_model,
                                   d_ff=d_ff,
                                   h=h,
                                   dropout=dropout,
                                   device=device)
        forward_pass_args = (train_data.to(device), y_in.to(device), x_mask, y_mask)
    elif model == "encoder_only_tst":
        model = make_encoder_only_time_series_model(d_input=d_input,
                                   d_output=d_output,
                                   N=N,
                                   d_model=d_model,
                                   d_ff=d_ff,
                                   h=h,
                                   dropout=dropout,
                                   device=device)
        forward_pass_args = (train_data.to(device), x_mask)
    elif model == "enc2dec_tst":
        model = make_encoder_to_decoder_time_series_model(d_input=d_input,
                                   d_output=d_output,
                                   N=N,
                                   d_model=d_model,
                                   d_ff=d_ff,
                                   h=h,
                                   dropout=dropout,
                                   device=device)
        forward_pass_args = (train_data.to(device), x_mask)

    optimizer = torch.optim.Adam(model.parameters())
    # Run `epochs` training steps to see whether a RuntimeError is thrown
    for idx_epoch in range(epochs):
        running_loss = 0
        optimizer.zero_grad()

        y_out_pred = model(*forward_pass_args)

        loss = loss_function(y_out_pred, train_data_out.to(device))  # (prediction, target) order
        loss.backward()
        optimizer.step()
        running_loss += loss.item()


def find_maximum_batch_size(d_model, N, h, d_ff, sequence_len):
    """Given the Transformer hyperparameters that (presumably) have a high influence on memory consumption,
    empirically determine the largest batch size (as a power of 2) that still runs on a single GPU
    (~6 GiB of used memory). The test is run for all three Transformer architectures (normal, encoder_only, encoder2decoder).

    :param d_model: (int) model dimensionality
    :param N: (int) layer depth
    :param h: (int) number of attention heads
    :param d_ff: (int) dimensionality of the position-wise feed-forward network
    :param sequence_len: (int) input and output sequence length
    :return: None; the maximum batch size of each architecture is printed in the order ['normal', 'encoder_only', 'encoder2decoder']
    """
    models = ["tst", "encoder_only_tst", "enc2dec_tst"]
    for model in models:
        print(f"Model: {model} - Finding maximum batch size for model with d_model: {d_model}, N:{N}, d_ff:{d_ff}, sequence_len:{sequence_len}")
        out_of_memory = False
        batch_size = 1
        while not out_of_memory:
            try:
                test_model_memory_consumption(model=model, d_model=d_model, N=N, h=h, d_ff=d_ff, sequence_len=sequence_len, batch_size=batch_size)
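            except RuntimeError:
                # Assumed to be a CUDA out-of-memory error; any other RuntimeError also ends the search
                out_of_memory = True
                batch_size = batch_size // 2  # last successful run (0 if even batch_size = 1 failed)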
            else:
                batch_size *= 2

        print(f"Maximum Batch Size: {batch_size}")
        print(f"Estimated maximum batch size if we use Log-Sparse:{int(batch_size * sequence_len / (np.log(sequence_len)**2))}")
        print("-"*50)
    print("Done")
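
The doubling search used in find_maximum_batch_size can be looked at in isolation. The following is a minimal sketch, not the notebook's code: `largest_power_of_two_batch`, `fits`, `limit`, and `fake_fits` are hypothetical names, and `fits` stands in for a call such as test_model_memory_consumption wrapped in a try/except that returns False on an out-of-memory error.

```python
def largest_power_of_two_batch(fits, limit=2 ** 20):
    """Return the largest power-of-two batch size for which fits(batch_size) is True.

    Returns 0 if even batch_size = 1 does not fit. `limit` is a safety bound
    so the loop terminates even if `fits` never returns False.
    """
    best = 0
    batch_size = 1
    while batch_size <= limit and fits(batch_size):
        best = batch_size  # remember the last successful size
        batch_size *= 2
    return best


# Stand-in for a real memory test: pretend everything up to 40 samples fits.
def fake_fits(batch_size):
    return batch_size <= 40


max_batch = largest_power_of_two_batch(fake_fits)
print(max_batch)
```

Tracking the last successful size explicitly avoids halving after the failure and makes the "nothing fits" case (result 0) easy to spot.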

## Define Hyperparameters that have a high influence on the memory consumption and calculate how large we could make the Batch size

In [12]:
d_model = 64  # Latent dimension
N = 8  # Number of encoder and decoder layers to stack
h = 8  # Number of attention heads
d_ff = 128  # Dimensionality of the position-wise feed-forward network
sequence_len = 256

find_maximum_batch_size(d_model=d_model, N=N, h=h, d_ff=d_ff, sequence_len=sequence_len)

Model: tst - Finding maximum batch size for model with d_model: 64, N:8, d_ff:128, sequence_len:256
Maximum Batch Size: 32
Estimated maximum batch size if we use Log-Sparse:266
--------------------------------------------------
Model: encoder_only_tst - Finding maximum batch size for model with d_model: 64, N:8, d_ff:128, sequence_len:256
Maximum Batch Size: 64
Estimated maximum batch size if we use Log-Sparse:532
--------------------------------------------------
Model: enc2dec_tst - Finding maximum batch size for model with d_model: 64, N:8, d_ff:128, sequence_len:256
Maximum Batch Size: 32
Estimated maximum batch size if we use Log-Sparse:266
--------------------------------------------------
Done
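
The Log-Sparse estimates in the output above come from the expression in find_maximum_batch_size: batch_size * sequence_len / ln(sequence_len)^2. This is a heuristic scaling of the measured batch size, not a measurement; it appears to assume that attention memory drops from order L^2 to order L (log L)^2. A small sketch that reproduces the printed numbers (the function name `log_sparse_batch_estimate` is introduced here for illustration only):

```python
import math


def log_sparse_batch_estimate(batch_size, sequence_len):
    """Scale a measured batch size by sequence_len / ln(sequence_len)^2."""
    return int(batch_size * sequence_len / math.log(sequence_len) ** 2)


# Measured maximum batch sizes from the run with sequence_len = 256.
for measured in (32, 64):
    print(measured, log_sparse_batch_estimate(measured, 256))
```

With ln(256) ≈ 5.545, the scaling factor is about 8.3, which turns 32 into 266 and 64 into 532.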


The model dimensionality (d_model) has less impact than I expected. The reason for this is described in the report.

The impact of N, h, the type of architecture, and sequence_len is significant.
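
As a rough illustration of why sequence_len and h can weigh more than d_model, the sketch below compares the size of one attention score tensor with one hidden-state tensor. It assumes standard multi-head scaled dot-product attention with float32 values, where the scores have shape (batch_size, h, sequence_len, sequence_len) and do not depend on d_model. This is an assumption about the implementation and not necessarily the explanation given in the report; the function names are hypothetical.

```python
def attention_scores_bytes(batch_size, h, sequence_len, bytes_per_value=4):
    """Bytes of one attention score tensor of shape (batch, h, L, L)."""
    return batch_size * h * sequence_len * sequence_len * bytes_per_value


def hidden_state_bytes(batch_size, sequence_len, d_model, bytes_per_value=4):
    """Bytes of one hidden-state tensor of shape (batch, L, d_model)."""
    return batch_size * sequence_len * d_model * bytes_per_value


MiB = 2 ** 20
scores = attention_scores_bytes(batch_size=32, h=8, sequence_len=256)
hidden = hidden_state_bytes(batch_size=32, sequence_len=256, d_model=64)
print(f"attention scores: {scores / MiB:.0f} MiB, hidden state: {hidden / MiB:.0f} MiB")
```

Under these assumptions, doubling sequence_len quadruples the score tensor, while doubling d_model only doubles the hidden state. The actual totals also depend on N, dropout masks, and what autograd keeps for the backward pass.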

Interesting setting:

d_model = 64  # Latent dimension
N = 6  # Number of encoder and decoder layers to stack
h = 8  # Number of attention heads
d_ff = 1024
sequence_len = 1024

This setting works with batch_size = 8 for the decoder-only Transformer, but it does not scale.