In [11]:
import itertools
import os
import subprocess

In [12]:
def mkdir(dir):
    """Create directory `dir` if it does not already exist.

    Missing parent directories are created as well, and calling this on an
    existing directory is a no-op. Using `exist_ok=True` instead of a separate
    `os.path.exists` check avoids the window in which another process could
    create the directory between the check and the `mkdir` call.

    Args:
        dir: path of the directory to create.

    Raises:
        FileExistsError: if `dir` exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)

In [13]:
# global job parameters
# (SLURM settings applied to every generated job script)

# output locations, relative to the notebook's working directory:
# job scripts go in `job_directory`, SLURM logs in `out_dir`
job_directory = "fineweb_edu"
out_dir = job_directory + '/.out'

# scheduler / resource request
partition = 'gpu'
job_duration = '00-48:00:00'  # passed to `--time` and to pretrain.py's `--job_duration`
nodes = 1
ntasks = 1
cpu_per_gpu = 8
mem_per_cpu = 16  # written to the job script with a `G` suffix
# GPU count and constraints are set per model config instead of globally:
# n_gpus = 1
# gpus_constraints = '"h100|a100"' # all gpus are pretty good now

# directory the job script changes into before launching pretrain.py
project_dir = "/home/ma2393/project/abstract_transformer/experiments/fineweb"

# make sure both output directories exist (parent first, then the nested log dir)
for required_dir in (job_directory, out_dir):
    mkdir(required_dir)

In [14]:
# model params
# Values shared by every model configuration below.
T = 1024
total_batch_size = 524_288

# One dict per job. Besides the model hyperparameters, each entry carries its own
# GPU request (`n_gpus`, `gpus_constraints`) and a `param_ct_string` used only
# as a run-name suffix.
model_params = [
    # dict(d_model=1024, n_layers=24, sa=8, ra=8,
    #     sym_attn_n_symbols=1024, sym_attn_n_heads=8,
    #     n_kv_heads=1, B=8,
    #     n_gpus=2, gpus_constraints= '"h100|a100"',
    #     param_ct_string='325M'),
    # dict(d_model=1024, n_layers=24, sa=8, ra=8,
    #     sym_attn_n_symbols=1024, sym_attn_n_heads=8,
    #     n_kv_heads=2, B=8,
    #     n_gpus=2, gpus_constraints= '"h100|a100"',
    #     param_ct_string='330M'),
    # dict(d_model=1024, n_layers=24, sa=8, ra=8,
    #     sym_attn_n_symbols=1024, sym_attn_n_heads=8,
    #     n_kv_heads=4, B=8,
    #     n_gpus=2, gpus_constraints= '"h100|a100"',
    #     param_ct_string='343M'),
    {
        'd_model': 1024, 'n_layers': 24, 'sa': 8, 'ra': 8,
        'sym_attn_n_symbols': 1024, 'sym_attn_n_heads': 8,
        'share_attn_params': 1, 'B': 8,
        'n_gpus': 2, 'gpus_constraints': '"h100|a100"',
        'param_ct_string': '343M',
    },
]

# (param key, name tag) pairs appended to the run name of models with ra > 0,
# in exactly this order, whenever the key is present in the config.
dat_name_tags = (
    ('n_relations', 'nr'),
    ('share_attn_params', 'sharedattn'),
    ('sym_attn_n_symbols', 'ns'),
    ('sym_attn_n_heads', 'sh'),
    ('shared_symbol_retriever', 'ssr'),
    ('weight_tie_symbol_library', 'wt'),
    ('trainable_symbols', 'ts'),
)

# Expand each model config into a job config by attaching a derived `run_name`.
jobs_params = []
for mparams in model_params:
    # configs with ra > 0 get the "DAT" prefix, the rest the "T" prefix
    is_dual_attention = mparams['ra'] > 0
    if is_dual_attention:
        name_parts = [f"DAT-sa{mparams['sa']}-ra{mparams['ra']}"]
        name_parts += [f"{tag}{mparams[key]}" for key, tag in dat_name_tags if key in mparams]
    else:
        name_parts = [f'T-sa{mparams["sa"]}']

    # suffixes common to both model families
    if 'n_kv_heads' in mparams:
        name_parts.append(f'nkvh{mparams["n_kv_heads"]}')
    if 'param_ct_string' in mparams:
        name_parts.append(f'{mparams["param_ct_string"]}')

    run_name = '-'.join(name_parts)

    # copy, so the entries of `model_params` are left unmodified
    jobs_params.append(dict(mparams, run_name=run_name))

In [15]:
# Inspect the expanded job configs (model params plus the derived `run_name`).
jobs_params

[{'d_model': 1024,
  'n_layers': 24,
  'sa': 8,
  'ra': 8,
  'sym_attn_n_symbols': 1024,
  'sym_attn_n_heads': 8,
  'share_attn_params': 1,
  'B': 8,
  'n_gpus': 2,
  'gpus_constraints': '"h100|a100"',
  'param_ct_string': '343M',
  'run_name': 'DAT-sa8-ra8-sharedattn1-ns1024-sh8-343M'}]

In [16]:
# Number of job configs; the job-creation cell writes one script per entry.
len(jobs_params)

1

In [17]:
# global config parameters
# NOTE(review): the meanings below are inferred from the names only — confirm
# against pretrain.py before relying on them.
n_epochs = 1  # presumably the number of passes over the training data
max_steps = -1  # presumably -1 means "no explicit step limit" — TODO confirm
log_to_wandb = 1  # presumably 1 enables Weights & Biases logging

In [18]:
# create jobs
# For each entry of `jobs_params`, write a SLURM batch script to
# `<job_directory>/<run_name>.job` and record its path in `created_jobs`
# (consumed by the submission cell).

# Optional model flags forwarded to pretrain.py as `--<key> <value>` whenever the
# key is present in the job's params. Each flag is emitted independently, so a
# config that sets only some of them (e.g. `sym_attn_n_symbols` without
# `sym_attn_n_heads`, or `trainable_symbols` without `shared_symbol_retriever`)
# neither raises KeyError nor silently drops a flag that the run name advertises.
# NOTE(review): `share_attn_params` is encoded in the run name but is NOT
# forwarded to pretrain.py (it never was in this cell). If pretrain.py accepts
# `--share_attn_params`, add the key to this tuple — TODO confirm.
optional_model_flags = (
    'n_relations',
    'sym_attn_n_symbols',
    'sym_attn_n_heads',
    'n_kv_heads',
    'shared_symbol_retriever',
    'weight_tie_symbol_library',
    'trainable_symbols',
)

created_jobs = []
for params in jobs_params:

    job_file = os.path.join(job_directory, f"{params['run_name']}.job")

    # SLURM resource request; `%j` in --output is sbatch's job-id placeholder
    header_lines = [
        "#!/bin/bash",
        f"#SBATCH --partition={partition}",
        f"#SBATCH --job-name={params['run_name']}",
        f"#SBATCH --output={out_dir}/%j-{params['run_name']}.out",
        f"#SBATCH --ntasks={ntasks} --nodes={nodes}",
        f"#SBATCH --cpus-per-gpu={cpu_per_gpu}",
        f"#SBATCH --mem-per-cpu={mem_per_cpu}G",
        f"#SBATCH --time={job_duration}",
        "#SBATCH --mail-type=ALL",
        f"#SBATCH --gpus={params['n_gpus']}",
    ]
    if 'gpus_constraints' in params:
        header_lines.append(f"#SBATCH --constraint={params['gpus_constraints']}")

    # environment setup executed on the compute node ('' entries are blank lines)
    setup_lines = [
        '',
        'module load StdEnv',
        'export SLURM_EXPORT_ENV=ALL',
        '',
        'module load miniconda',  # load modules i need
        'conda activate abstract_transformer',  # activate conda environment
        'conda info --envs',  # list conda environments in the job log
        '',
        'nvidia-smi -L',  # print gpu information
        '',
        f"cd {project_dir}",  # navigate to project directory
        '',
    ]

    # run python script: torchrun for multi-GPU jobs, plain python otherwise
    if params['n_gpus'] > 1:
        launcher = f"torchrun --standalone --nproc_per_node={params['n_gpus']} pretrain.py"
    else:
        launcher = "python pretrain.py"

    command_parts = [
        launcher,
        f"--d_model {params['d_model']} --sa {params['sa']} --ra {params['ra']} --n_layers {params['n_layers']}",
    ]
    command_parts += [f"--{key} {params[key]}" for key in optional_model_flags if key in params]
    command_parts.append(f"--T {T} --B {params['B']} --total_batch_size {total_batch_size}")
    # NOTE(review): `--wandb_log` is hard-coded to 1; the notebook's `log_to_wandb`
    # setting is not consulted here — confirm whether it should be.
    command_parts.append(f"--wandb_log 1 --run_name {params['run_name']} --job_duration {job_duration}")

    # Join with shell line continuations. Joining (rather than appending " \"
    # to every line) keeps the final line free of a dangling continuation.
    command = " \\\n\t".join(command_parts)

    with open(job_file, 'w') as fh:
        fh.write("\n".join(header_lines + setup_lines + [command]) + "\n")

    created_jobs.append(job_file)

In [19]:
# Paths of the job scripts written by the job-creation cell.
created_jobs

['fineweb_edu/DAT-sa8-ra8-sharedattn1-ns1024-sh8-343M.job']

In [21]:
# Submit every generated job script with `sbatch`, but only after the user
# explicitly confirms with "y" (surrounding whitespace and case are ignored).
confirm = input("CONTINUE TO RUN ALL JOBS?")
if confirm.strip().lower() == 'y':
    for job in created_jobs:
        try:
            # An argument list (no shell) means the job path is passed verbatim,
            # whatever characters the run name contains.
            result = subprocess.run(['sbatch', job], capture_output=True, text=True)
        except FileNotFoundError:
            # sbatch is not on PATH, so no later submission can succeed either
            print('sbatch NOT FOUND; JOBS NOT SUBMITTED')
            break
        # echo sbatch's own message (e.g. the assigned job id) into the cell output
        print(result.stdout, end='')
        if result.returncode != 0:
            # surface failures instead of ignoring the exit status
            print(f'FAILED TO SUBMIT {job} (sbatch exit code {result.returncode})')
            print(result.stderr, end='')
else:
    print('JOBS NOT SUBMITTED')

Submitted batch job 22918
