In [9]:
import os
import itertools
from dotenv import load_dotenv
import copy
import time
import subprocess
import yaml


import os, sys; sys.path.insert(0, os.path.abspath('../..')) # add project root dir to path
from fineweb.model_recurrent import get_experiment_name
from utils.utils import AttributeDict

load_dotenv()

def mkdir(dir):
    """Create directory `dir` if it does not already exist.

    Missing parent directories are created as well, so nested paths such as
    "job_scripts/recurrent" work on a fresh checkout. Calling this on an
    existing directory is a no-op.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # `if not exists: mkdir` and handles nested paths, which os.mkdir does not.
    os.makedirs(dir, exist_ok=True)


In [10]:
# global job parameters

job_directory = f"job_scripts/recurrent"
out_dir = f'.out'

time_str = '00-24:00:00' # SLURM wall-clock limit, written to `#SBATCH --time`
max_time = '00:23:55:00' # 5 minutes less than the time_str; this is the format PL uses

partition = 'gpu'
ntasks = 1
nodes = 1
cpu_per_gpu = 8
mem_per_cpu = 8 # in GB (written as `--mem-per-cpu={mem_per_cpu}G`)
n_gpus = 1

cluster = 'misha'

# Always define gpus_constraints: create_job_script checks `gpus_constraints is not None`,
# which raises NameError on a fresh kernel if the name only exists for cluster == 'grace'.
gpus_constraints = None
if cluster == 'grace':
    gpus_constraints = '"a100|rtx3090|v100|rtx2080ti"' # for grace
# gpus_constraints = "a40" #'"h100|a100"' # for misha

netid = os.getenv('NETID')
project_dir = f"/home/{netid}/project/adaptive-hyperspherical-res-stream/fineweb"

mkdir(job_directory)
mkdir(out_dir)

In [11]:
# Read the base model / train / data configs that every job config is derived from
import yaml
configs_dir = f'{project_dir}/configs/recurrent'
base_config_dir = f'{configs_dir}/base_config'

loaded_base_configs = []
for config_file in ('model_config.yaml', 'train_config.yaml', 'data_config.yaml'):
    with open(os.path.join(base_config_dir, config_file)) as f:
        # each YAML file is wrapped in an AttributeDict, same as the per-job configs later on
        loaded_base_configs.append(AttributeDict(yaml.load(f, Loader=yaml.FullLoader)))

base_model_config, base_train_config, base_data_config = loaded_base_configs


In [12]:
# Sweep definition: which model types, architectures and per-model-type options to launch.
# The next cell takes the cartesian product of `model_archs` with every combination of the
# list-valued options in `model_type_configs[model_type]`.

# NOTE(review): within this cell D, L and T are only referenced by the commented-out template
# entry in `model_archs`; H is used as `n_spheres=H` below.
D, L, T, H = 256, 2, 1, 4
sequence_length_map = {256: 256, 384: 512, 512: 512} # map from d_model to sequence_length

manual_norm_weights = True
micro_batch_size = 64

# model_types = ['llama', 'nGPT', 'baseline_transformer'] # 'llama', 'nGPT'
model_types = ['nGPT'] # 'llama', 'nGPT'

# One dict per architecture; each is merged into the job's model config as-is.
# NOTE(review): n_layers * n_iters is 6/6/6 for d_model=256 and 8/8/8 for d_model=384, but
# 6/8/8 for d_model=512 (first entry has n_iters=6) — confirm this is intentional.
model_archs = [
        # dict(n_layers=L, d_model=D, n_heads=H, n_iters=T),

        dict(d_model=256, n_heads=4, n_layers=1, n_iters=6),
        dict(d_model=256, n_heads=4, n_layers=2, n_iters=3),
        dict(d_model=256, n_heads=4, n_layers=6, n_iters=1),

        dict(d_model=384, n_heads=8, n_layers=1, n_iters=8),
        dict(d_model=384, n_heads=8, n_layers=2, n_iters=4),
        dict(d_model=384, n_heads=8, n_layers=8, n_iters=1),

        dict(d_model=512, n_heads=8, n_layers=1, n_iters=6),
        dict(d_model=512, n_heads=8, n_layers=2, n_iters=4),
        dict(d_model=512, n_heads=8, n_layers=8, n_iters=1),
        ]

# Per-model-type options. Every value is a LIST of alternatives; one job is created per
# combination of alternatives (an empty dict yields exactly one combination).
model_type_configs = dict(
    nGPT = dict(
        # Each active (uncommented) entry is one residual-module variant to run.
        residual_module_args = [
            dict(residual_module='ResidualSphericalLERPBase'),

            # dict(residual_module='ResidualSphericalSLERP', residual_module_kwargs=dict(single_weight=True)),
            # dict(residual_module='ResidualSphericalSLERP', residual_module_kwargs=dict(single_weight=False)),
            # dict(residual_module='ResidualSphericalSLERP', residual_module_kwargs=dict(single_weight=True, n_spheres=H)),
            # dict(residual_module='ResidualSphericalSLERP', residual_module_kwargs=dict(single_weight=False, n_spheres=H)),

            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
            #         residual_module_kwargs=dict(single_weight=True, slerp_weight_map='NormLinear', interpolation_weight_activation='linear')),
            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
            #     residual_module_kwargs=dict(single_weight=False, slerp_weight_map='NormLinear', interpolation_weight_activation='linear')),

            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
            #     residual_module_kwargs=dict(single_weight=True, slerp_weight_map='NormLinear', interpolation_weight_activation='sigmoid')),
            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
            #     residual_module_kwargs=dict(single_weight=False, slerp_weight_map='NormLinear', interpolation_weight_activation='sigmoid')),

            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
            #     residual_module_kwargs=dict(single_weight=True, slerp_weight_map='NormLinear', interpolation_weight_activation='sigmoid', bias=True)),
            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
            #     residual_module_kwargs=dict(single_weight=False, slerp_weight_map='NormLinear', interpolation_weight_activation='sigmoid', bias=True)),

            # NOTE(review): n_spheres=H uses the module-level H (= 4) for every architecture,
            # while the d_model=384 and d_model=512 entries above use n_heads=8 — confirm
            # n_spheres is not meant to track each architecture's n_heads.
            dict(residual_module='ResidualAdaptiveSphericalSLERP',
                residual_module_kwargs=dict(single_weight=True, n_spheres=H, slerp_weight_map='NormLinear', interpolation_weight_activation='sigmoid')),
            # dict(residual_module='ResidualAdaptiveSphericalSLERP',
                # residual_module_kwargs=dict(single_weight=False, n_spheres=H, slerp_weight_map='NormLinear', interpolation_weight_activation='sigmoid')),
            ],
            manual_norm_weights = [manual_norm_weights],
            gpt_special_init = [True]
        ),

    # no swept options for llama: a single job per architecture
    llama = dict(),

    baseline_transformer = dict(
        norm_config = [dict(norm_method='pre-norm', norm_type='rmsnorm')],
        mlp_activation = ['swiglu'],
        pos_enc_type = ['rotary'],
        bias = [False],
        gpt_special_init = [False, True],
    ),

)


# W&B project name written into each job's train config
wandb_project = 'recurrent-language-modeling-nGPT'


In [13]:
# Build one (model, train, data) config triple per job: every architecture in `model_archs`
# crossed with every combination of the list-valued options of each selected model type.
jobs_overwrite_params = []

print('number of model architecture configs: ', len(model_archs))

for model_type in model_types:
    print()
    print('creating jobs for model_type =', model_type)

    # each element is a tuple of (option_name, option_value) pairs, one pair per swept option
    model_type_config_product = list(itertools.product(*[[(k, v) for v in vs] for k, vs in model_type_configs[model_type].items()]))

    print('number of model_type_configs:', len(model_type_config_product))

    for model_arch_config, model_type_config in itertools.product(model_archs, model_type_config_product):
        # copy base configs
        job_model_config = copy.deepcopy(base_model_config)
        job_train_config = copy.deepcopy(base_train_config)
        job_data_config = copy.deepcopy(base_data_config)

        # set_model_type
        job_model_config['model_type'] = model_type

        # update model config
        job_model_config.update(copy.deepcopy(model_arch_config)) # update with model_architecture

        # update kwargs for model_type
        # Deep-copy the swept values: they are nested dicts shared by every architecture, so
        # storing them by reference would make all jobs alias (and potentially mutate) the same
        # objects in `model_type_configs`.
        # NOTE(review): a swept value that is itself a dict (e.g. an entry of
        # `residual_module_args`) ends up nested under its own key, i.e.
        # nGPT_kwargs['residual_module_args'] = {...}. The recorded output of the config-writing
        # cell shows identical experiment names for both residual-module variants of each
        # architecture — confirm the model and get_experiment_name actually read this key.
        model_type_config_update = copy.deepcopy(dict(model_type_config))
        kwargs_key = f'{model_type}_kwargs'
        if kwargs_key in job_model_config:
            job_model_config[kwargs_key].update(model_type_config_update)
        else:
            job_model_config[kwargs_key] = model_type_config_update

        # remove other model_type kwargs
        for other_model_type in ['nGPT', 'llama', 'baseline_transformer']:
            if other_model_type != model_type and f'{other_model_type}_kwargs' in job_model_config:
                del job_model_config[f'{other_model_type}_kwargs']

        # parse train_config
        job_train_config['wandb_config'] = job_train_config['wandb_config'] | dict(wandb_project=wandb_project)

        job_train_config['max_time'] = max_time

        job_train_config['micro_batch_size'] = micro_batch_size

        # set learning rate schedule config
        job_train_config['cosine_scheduler_config'] = dict(warmup_steps=0, max_steps=None) # test warmup_steps=0, for normalized models

        # update data config
        job_data_config['sequence_length'] = sequence_length_map[job_model_config['d_model']]

        job_config = dict(model_config=job_model_config, train_config=job_train_config, data_config=job_data_config)
        job_config = AttributeDict(job_config)
        jobs_overwrite_params.append(job_config)

print('\ntotal number of jobs:', len(jobs_overwrite_params))

number of model architecture configs:  9

creating jobs for model_type = nGPT
number of model_type_configs: 2

total number of jobs: 18


In [15]:
def create_job_config(job_configs, out_dir, uid=None):
    """Write one job's model/train/data configs as YAML files and return them.

    Args:
        job_configs: object exposing `model_config`, `train_config` and `data_config`
            attributes; each must provide a `todict()` method.
        out_dir: directory under which a sub-directory named after the experiment is created.
        uid: optional identifier; when given, the experiment name is prefixed with "UID{uid}-".

    Returns:
        (model_config, train_config, data_config, experiment_name)
    """
    model_config = job_configs.model_config
    train_config = job_configs.train_config
    data_config = job_configs.data_config

    # only the first element returned by get_experiment_name is used; spaces are stripped
    raw_name, _ = get_experiment_name(model_config, data_config, train_config)
    experiment_name = raw_name.replace(' ', '')
    if uid is not None:
        experiment_name = f"UID{uid}-{experiment_name}"

    mkdir(os.path.join(out_dir, experiment_name))

    configs_by_file = (
        ('model_config.yaml', model_config),
        ('train_config.yaml', train_config),
        ('data_config.yaml', data_config),
    )
    for file_name, config in configs_by_file:
        with open(os.path.join(out_dir, f'{experiment_name}/{file_name}'), 'w') as f:
            yaml.dump(config.todict(), f)

    return model_config, train_config, data_config, experiment_name

In [16]:
def create_job_script(experiment_name):
    """Write the SLURM batch script for one experiment and return its path.

    The script is written to `{job_directory}/{experiment_name}.job` and is built from the
    notebook-level job parameters (partition, ntasks, nodes, cpu_per_gpu, mem_per_cpu, n_gpus,
    time_str, out_dir, cluster, project_dir, configs_dir and, optionally, gpus_constraints).

    Args:
        experiment_name: name used for the job, its output file, the script file, and the
            config sub-directory passed to `train_recurrent.py`.

    Returns:
        Path of the written job script.

    Raises:
        ValueError: if `cluster` is not one of the supported clusters. This is checked before
            the file is opened so that no truncated job script is left behind.
    """
    # environment-setup command per supported cluster
    module_setup_by_cluster = {
        'grace': "module restore python_env\n",
        'misha': "module load miniconda\n",
    }
    if cluster not in module_setup_by_cluster:
        raise ValueError(f"Cluster {cluster} not supported")

    # gpus_constraints may never have been assigned (e.g. it is only set for some clusters);
    # treat a missing name the same as None instead of failing with NameError.
    constraint = globals().get('gpus_constraints')

    if cluster == 'misha':
        cpu_directive = f"#SBATCH --cpus-per-gpu={cpu_per_gpu}\n"
    else:
        cpu_directive = f"#SBATCH --cpus-per-task={cpu_per_gpu * n_gpus}\n"

    lines = [
        "#!/bin/bash\n",
        f"#SBATCH --partition={partition}\n",
        f"#SBATCH --job-name={experiment_name}\n",
        f"#SBATCH --output={out_dir}/%j-{experiment_name}.out\n",
        f"#SBATCH --ntasks={ntasks} --nodes={nodes}\n",
        cpu_directive,
        f"#SBATCH --mem-per-cpu={mem_per_cpu}G\n",
        f"#SBATCH --time={time_str}\n",
        "#SBATCH --mail-type=ALL\n",
        f"#SBATCH --gpus={n_gpus}\n",
    ]
    if constraint is not None:
        lines.append(f"#SBATCH --constraint={constraint}\n")

    lines += [
        '\n',
        'module load StdEnv\n',
        'export SLURM_EXPORT_ENV=ALL\n',
        '\n',
        module_setup_by_cluster[cluster],  # load modules i need
        "conda activate neural_prog\n",  # activate conda environment
        "conda info --envs\n",  # list conda environments (shows which one is active)
        '\n',
        "nvidia-smi -L\n",  # print gpu information
        '\n',
        f"cd {project_dir}\n",  # navigate to project directory
        '\n',
        # run python script
        f"srun python train_recurrent.py --config_dir {configs_dir}/{experiment_name}\n",
    ]

    filename = f'{job_directory}/{experiment_name}.job'
    with open(filename, 'w') as fh:
        fh.writelines(lines)

    return filename


In [17]:
# Write the YAML configs and the SLURM script for every job; collect the script paths.
job_script_files = []

for uid, job_params in enumerate(jobs_overwrite_params):
    # Only the experiment name is needed here. The returned configs are deliberately NOT
    # unpacked into base_model_config / base_train_config / base_data_config: doing so would
    # overwrite the base configs with the last job's configs, so re-running the job-building
    # cell afterwards would start from the wrong base.
    *_, experiment_name = create_job_config(job_params, configs_dir, uid=uid)

    print(f"Experiment Name: {experiment_name}")

    job_script = create_job_script(experiment_name)
    job_script_files.append(job_script)

Experiment Name: UID0-nGPT-L1T6H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256
Experiment Name: UID1-nGPT-L1T6H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256
Experiment Name: UID2-nGPT-L2T3H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256
Experiment Name: UID3-nGPT-L2T3H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256
Experiment Name: UID4-nGPT-L6T1H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256
Experiment Name: UID5-nGPT-L6T1H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256
Experiment Name: UID6-nGPT-L1T8H8D384-ResidualSphericalLERPBase-MNW-True-GPTInit-True-512
Experiment Name: UID7-nGPT-L1T8H8D384-ResidualSphericalLERPBase-MNW-True-GPTInit-True-512
Experiment Name: UID8-nGPT-L2T4H8D384-ResidualSphericalLERPBase-MNW-True-GPTInit-True-512
Experiment Name: UID9-nGPT-L2T4H8D384-ResidualSphericalLERPBase-MNW-True-GPTInit-True-512
Experiment Name: UID10-nGPT-L8T1H8D384-ResidualSphericalLERPBase-MNW-True-GPTInit-True-512
Experimen

In [18]:
# Submit every generated job script with sbatch, after an interactive confirmation.
wait_time = 0.5 # number of seconds to wait between job submissions
n_trials = 1 # how many times the full set of scripts is submitted

confirm = input("Do you want to submit the jobs? (y/n): ")

# raw CompletedProcess objects, inspected by the failure check in the next cell
responses = []

if confirm != 'y':
    print("Not submitting jobs")
else:
    for ir in range(n_trials):
        print('Trial:', ir)
        for job_script in job_script_files:
            response = subprocess.run(['sbatch', job_script], capture_output=True)
            sbatch_message = response.stdout.decode('utf-8').strip()
            print(f"response: {sbatch_message}, return_code={response.returncode}, job_script={job_script}")
            responses.append(response)
            # pause between submissions
            time.sleep(wait_time)
        print()

Trial: 0
response: Submitted batch job 140675, return_code=0, job_script=job_scripts/recurrent/UID0-nGPT-L1T6H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256.job
response: Submitted batch job 140676, return_code=0, job_script=job_scripts/recurrent/UID1-nGPT-L1T6H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256.job
response: Submitted batch job 140677, return_code=0, job_script=job_scripts/recurrent/UID2-nGPT-L2T3H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256.job
response: Submitted batch job 140678, return_code=0, job_script=job_scripts/recurrent/UID3-nGPT-L2T3H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256.job
response: Submitted batch job 140679, return_code=0, job_script=job_scripts/recurrent/UID4-nGPT-L6T1H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256.job
response: Submitted batch job 140680, return_code=0, job_script=job_scripts/recurrent/UID5-nGPT-L6T1H4D256-ResidualSphericalLERPBase-MNW-True-GPTInit-True-256.job
response: Sub

In [19]:
# Report every submission whose sbatch output or return code indicates a failure
for response in responses:
    stdout_text = response.stdout.decode('utf-8')
    submitted_ok = stdout_text.startswith('Submitted batch job') and response.returncode == 0
    if submitted_ok:
        continue

    print(f"Failed to submit job: {stdout_text}")
    print(f"stderr: {response.stderr.decode('utf-8')}")
    print(f"Full response: {response}")
    print()