In [10]:
import os
import json

In [11]:
# Root directory that holds one sub-directory per base model.
# The original cluster path stays the default; set LLM_BASE_MODELS_DIR to run
# the notebook on another machine without editing this cell.
base_path = os.environ.get('LLM_BASE_MODELS_DIR', '/NS/llm-1/nobackup/vnanda/llm_base_models/')

In [12]:
# Names of the model directories to inspect; each one is joined onto
# base_path to locate that model's config.json.
models = [
    'pythia-70m',
    'pythia-1b',
    'pythia-6.9b',
    'pythia-12b',
]

In [13]:
def get_details_from_config(config):
    """Read a JSON config file and return the model's architecture sizes.

    Args:
        config: Path to a JSON file whose top-level object contains the keys
            'num_hidden_layers', 'num_attention_heads' and 'hidden_size'.

    Returns:
        A dict with the keys 'layers', 'attention_heads' and 'hidden_size',
        holding the corresponding values from the file.

    Raises:
        KeyError: If one of the three expected keys is missing from the file.
    """
    with open(config, 'r') as config_file:
        parsed = json.load(config_file)

    # Output key -> key used in the config file, in the order they are read.
    source_keys = {
        'layers': 'num_hidden_layers',
        'attention_heads': 'num_attention_heads',
        'hidden_size': 'hidden_size',
    }
    return {name: parsed[source] for name, source in source_keys.items()}

In [14]:
# Load the architecture sizes from each model's 'config.json', which is
# looked up at <base_path>/<model>/config.json.
model_details = {}

for model in models:
    config_file = os.path.join(base_path, model, 'config.json')
    model_details[model] = get_details_from_config(config_file)

In [15]:
def no_parallelism_memory_approximation(layers, attention_heads, hidden_size, batch_size, sequence_length):
    """Approximate the memory, in GiB, needed when no parallelism is used.

    The per-layer estimate in bytes is ``s * b * h * (34 + 5 * a * s / h)``
    with ``s = sequence_length``, ``b = batch_size``, ``h = hidden_size`` and
    ``a = attention_heads``. It is multiplied by ``layers`` and converted from
    bytes to gigabytes by dividing by ``1024 ** 3``.

    NOTE(review): the expression has the same form as the per-layer activation
    memory estimate in Korthikanti et al., "Reducing Activation Recomputation
    in Large Transformer Models" -- confirm that this is the intended source
    and that weights/optimizer state are deliberately excluded.

    Args:
        layers: Number of layers.
        attention_heads: Number of attention heads.
        hidden_size: Hidden dimension; must be non-zero (it is a divisor).
        batch_size: Batch size.
        sequence_length: Sequence length.

    Returns:
        The estimated memory in gigabytes (1 GB = 1024 ** 3 bytes) as a float.
    """
    bytes_per_gigabyte = 1024 ** 3

    # s * b * h: shared multiplier of the per-layer estimate.
    elements = sequence_length * batch_size * hidden_size
    # 34 + 5 * a * s / h: bytes per element of the multiplier above.
    bytes_per_element = 34 + (5 * attention_heads * sequence_length) / hidden_size

    bytes_all_layers = layers * (elements * bytes_per_element)
    return bytes_all_layers / bytes_per_gigabyte

In [20]:
# Spot-check the estimator on one hand-picked configuration.
no_parallelism_memory_approximation(
    layers=32,
    attention_heads=32,
    hidden_size=4096,
    batch_size=1,
    sequence_length=1000,
)

8.91876220703125

In [16]:
# Grid of batch sizes and sequence lengths to evaluate the estimate on.
BATCH_SIZES = [1, 2, 4, 6, 8]
# Small lengths first, then the powers of two from 16 up to 2048.
CONTEXT_LENGTHS = [1, 2, 4, 6, 8] + [2 ** exponent for exponent in range(4, 12)]

In [21]:
# Build the nested mapping model -> batch size -> sequence length -> estimate
# (in gigabytes), printing every entry as it is computed.
memory_approximations = {}
for model in models:
    details = model_details[model]
    per_batch_size = memory_approximations[model] = {}
    for batch_size in BATCH_SIZES:
        per_sequence_length = per_batch_size[batch_size] = {}
        for sequence_length in CONTEXT_LENGTHS:
            per_sequence_length[sequence_length] = no_parallelism_memory_approximation(
                details['layers'],
                details['attention_heads'],
                details['hidden_size'],
                batch_size,
                sequence_length,
            )
            print(f'{model} {batch_size} {sequence_length} {memory_approximations[model][batch_size][sequence_length]}')

pythia-70m 1 1 9.749829769134521e-05
pythia-70m 1 2 0.00019544363021850586
pythia-70m 1 4 0.00039267539978027344
pythia-70m 1 6 0.0005916953086853027
pythia-70m 1 8 0.0007925033569335938
pythia-70m 1 16 0.001613616943359375
pythia-70m 1 32 0.0033416748046875
pythia-70m 1 64 0.00714111328125
pythia-70m 1 128 0.01611328125
pythia-70m 1 256 0.03955078125
pythia-70m 1 512 0.1083984375
pythia-70m 1 1024 0.333984375
pythia-70m 1 2048 1.13671875
pythia-70m 2 1 0.00019499659538269043
pythia-70m 2 2 0.0003908872604370117
pythia-70m 2 4 0.0007853507995605469
pythia-70m 2 6 0.0011833906173706055
pythia-70m 2 8 0.0015850067138671875
pythia-70m 2 16 0.00322723388671875
pythia-70m 2 32 0.006683349609375
pythia-70m 2 64 0.0142822265625
pythia-70m 2 128 0.0322265625
pythia-70m 2 256 0.0791015625
pythia-70m 2 512 0.216796875
pythia-70m 2 1024 0.66796875
pythia-70m 2 2048 2.2734375
pythia-70m 4 1 0.00038999319076538086
pythia-70m 4 2 0.0007817745208740234
pythia-70m 4 4 0.0015707015991210938
pythia-70m 

In [19]:
# Display the full nested mapping: model -> batch size -> sequence length ->
# estimated memory in gigabytes.
memory_approximations

{'pythia-70m': {1: {1: 9.749829769134521e-05,
   2: 0.00019544363021850586,
   4: 0.00039267539978027344,
   6: 0.0005916953086853027,
   8: 0.0007925033569335938,
   16: 0.001613616943359375,
   32: 0.0033416748046875,
   64: 0.00714111328125,
   128: 0.01611328125,
   256: 0.03955078125,
   512: 0.1083984375,
   1024: 0.333984375,
   2048: 1.13671875},
  2: {1: 0.00019499659538269043,
   2: 0.0003908872604370117,
   4: 0.0007853507995605469,
   6: 0.0011833906173706055,
   8: 0.0015850067138671875,
   16: 0.00322723388671875,
   32: 0.006683349609375,
   64: 0.0142822265625,
   128: 0.0322265625,
   256: 0.0791015625,
   512: 0.216796875,
   1024: 0.66796875,
   2048: 2.2734375},
  4: {1: 0.00038999319076538086,
   2: 0.0007817745208740234,
   4: 0.0015707015991210938,
   6: 0.002366781234741211,
   8: 0.003170013427734375,
   16: 0.0064544677734375,
   32: 0.01336669921875,
   64: 0.028564453125,
   128: 0.064453125,
   256: 0.158203125,
   512: 0.43359375,
   1024: 1.3359375,
   20

In [18]:
# save to json
# NOTE: JSON object keys are always strings, so the integer batch-size and
# sequence-length keys are written as strings ("1", "2", ...); convert them
# back with int() when loading the file.
output_path = 'Measurements/no_parallelism_memory_approximations.json'

# Create the output directory first so a fresh checkout does not fail with
# FileNotFoundError when opening the file for writing.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'w') as f:
    json.dump(memory_approximations, f)