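This notebook requires that the timing experiment `jobscripts/memory_efficiency/time_test_all.sh` has already been run; the cells below read its `time_test*` result files from the logging directory.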

In [40]:
import os

# Directory holding one sub-directory per logged run.
logging_dir = '/work3/s184399/msc/logs/'

# Collect (run directory, timing file) pairs for every run directory whose name
# contains 'time_test' and which actually holds a 'time_test' result file.
# Entries are sorted because os.listdir order is arbitrary; this keeps the
# printed summaries below in a stable order.
result_dirs = []
result_files = []
for run_dir in sorted(os.listdir(logging_dir)):
    run_path = os.path.join(logging_dir, run_dir)
    # Skip non-matching names and plain files (os.listdir on a file would raise).
    if 'time_test' not in run_dir or not os.path.isdir(run_path):
        continue
    matches = sorted(f for f in os.listdir(run_path) if 'time_test' in f)
    # Skip runs that produced no timing file instead of failing with IndexError.
    if matches:
        result_dirs.append(run_dir)
        result_files.append(matches[0])

In [41]:
import numpy as np

def bootstrap_CI(p, alpha=0.05, k=2000, n_decimals=0, rng=None):
    """
    Compute a bootstrap confidence interval for the mean of ``p``.

    The interval is the central 100*(1-alpha)% percentile interval of ``k``
    bootstrap means, i.e. from percentile 100*(alpha/2) to 100*(1-alpha/2).
    When the percentile positions are not integers, the indices are rounded
    outwards, which gives the broadest interval. Line Clemmensen suggests
    k = 1000 or 2000 repeats for this task, hence the default.

    Parameters
    ----------
    p : np.ndarray
        1-D array of observations. The formatted bounds get an 's' suffix.
    alpha : float
        Significance level; the interval covers 100*(1-alpha)%.
    k : int
        Number of bootstrap resamples.
    n_decimals : int
        Number of decimals used when formatting the bounds.
    rng : None, int, or np.random.Generator, optional
        Source of randomness. ``None`` (default) uses the legacy global
        ``np.random`` state, as before. An int seed or a Generator makes the
        result reproducible.

    Returns
    -------
    (list[str], int)
        The lower and upper bounds formatted like ``'1.23s'``, and the support
        N (number of observations).

    Raises
    ------
    AssertionError
        If ``p`` is not a 1-D numpy array.
    ValueError
        If ``p`` is empty.
    """
    assert isinstance(p, np.ndarray)
    assert p.ndim == 1
    N = len(p)
    if N == 0:
        raise ValueError("bootstrap_CI requires at least one observation")

    sampler = np.random if rng is None else np.random.default_rng(rng)
    # Each row is one resample of size N drawn with replacement.
    bootstraps = sampler.choice(p, (k, N), replace=True)
    means = np.sort(np.mean(bootstraps, axis=-1))  # lowest to highest

    ci_lower = alpha / 2.
    ci_upper = 1. - (alpha / 2.)
    # 0-based indices rounded outwards. The upper index is ceil(k*q) - 1, so
    # the same number of means fall below the lower bound as above the upper
    # bound (k=2000, alpha=0.05 -> indices 50 and 1949). Clamping keeps a tiny
    # alpha from indexing past the end of the array.
    lo = min(max(0, int(np.floor(k * ci_lower))), k - 1)
    hi = min(k - 1, int(np.ceil(k * ci_upper)) - 1)
    hi = max(hi, lo)

    CI = means[[lo, hi]]
    # Sanity check. Equality is legitimate when all resampled means coincide,
    # e.g. for constant data.
    assert CI[0] <= CI[1]
    CI = [f"{c:.{n_decimals}f}s" for c in CI]
    return CI, N    # Returns CI and support (N)

In [42]:
import numpy as np
import matplotlib.pyplot as plt

# Print the bootstrap CI of the mean time (bootstrap_CI default: 95%) for every run found above.
# The first recorded time is skipped. NOTE(review): presumably a warm-up
# iteration with one-off setup cost; confirm against the job script.
for run_dir, run_file in zip(result_dirs, result_files):
    # ndmin=1 keeps a single-row file as a 1-D array instead of a 0-d scalar.
    times = np.loadtxt(os.path.join(logging_dir, run_dir, run_file), delimiter=',', skiprows=1, ndmin=1)
    label = run_file.replace('time_test_', '').replace('.txt', '')
    print(f"{label}: {bootstrap_CI(times[1:], n_decimals=2)[0]}")

Llama-2-7b-hf_MEBP_sgd: ['35.74s', '35.98s']
opt-125m_adamw_torch: ['2.63s', '2.64s']
opt-125m_sgd: ['1.08s', '1.09s']
Llama-2-7b-hf_sgd: ['38.59s', '39.44s']
Llama-2-7b-hf_QLoRA_adamw_torch: ['5.57s', '5.69s']
opt-125m_MEBP_adam: ['1.32s', '1.33s']
opt-125m_MEBP_sgd: ['1.17s', '1.19s']
opt-125m_QLoRA_adamw_torch: ['0.87s', '0.88s']


In [44]:
import numpy as np
import matplotlib.pyplot as plt

# Print the plain mean time for every run found above. This is the point
# estimate that the bootstrap CIs bracket.
# The first recorded time is skipped, matching the CI cell. NOTE(review):
# presumably a warm-up iteration; confirm against the job script.
for run_dir, run_file in zip(result_dirs, result_files):
    # ndmin=1 keeps a single-row file as a 1-D array instead of a 0-d scalar.
    times = np.loadtxt(os.path.join(logging_dir, run_dir, run_file), delimiter=',', skiprows=1, ndmin=1)
    label = run_file.replace('time_test_', '').replace('.txt', '')
    print(f"{label}: {np.mean(times[1:])}")

Llama-2-7b-hf_MEBP_sgd: 35.85602824298703
opt-125m_adamw_torch: 2.632873214021021
opt-125m_sgd: 1.0848106097201913
Llama-2-7b-hf_sgd: 39.01320312947643
Llama-2-7b-hf_QLoRA_adamw_torch: 5.616802741070183
opt-125m_MEBP_adam: 1.323660736181298
opt-125m_MEBP_sgd: 1.1777102606637138
opt-125m_QLoRA_adamw_torch: 0.8773686812848461
