In [1]:
import torch

import sys; sys.path.append('..')
from language_models import DualAttnTransformerLM, TransformerLM
from hf import DualAttnTransformerLM_HFHub
from huggingface_hub import ModelCard, ModelCardData
from datetime import datetime

In [2]:
# Device that checkpoint tensors are mapped onto when loaded (passed as
# `map_location` to torch.load by the loader functions below).
# Pinned to CPU; the trailing comment keeps the GPU-if-available alternative.
device = 'cpu' # torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_from_ckpt(ckpt_path):
    """Rebuild a language model from a training checkpoint file.

    The checkpoint is expected to hold a ``'config'`` entry (keyword arguments
    for the model constructor) and a ``'model'`` entry (the state dict).

    Args:
        ckpt_path: path of the checkpoint, read with ``torch.load`` and mapped
            onto the module-level ``device``.

    Returns:
        A ``DualAttnTransformerLM`` when the config contains ``'n_heads_ra'``,
        otherwise a ``TransformerLM``, with the checkpoint weights loaded.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    config = checkpoint['config']

    # Drop every '_orig_mod.' fragment from parameter names so they match the
    # plain module (presumably left behind by torch.compile wrapping — confirm).
    weights = {}
    for name, tensor in checkpoint['model'].items():
        weights[name.replace('_orig_mod.', '')] = tensor

    # Only dual-attention configs carry the relational-attention head count.
    model_cls = DualAttnTransformerLM if 'n_heads_ra' in config else TransformerLM
    lm = model_cls(**config)
    lm.load_state_dict(weights)
    return lm

def load_from_ckpt_hf(ckpt_path):
    """Rebuild a model from a checkpoint, using the Hub-enabled dual-attention class.

    Reads the ``'config'`` and ``'model'`` entries of the checkpoint, builds the
    model from the config and loads the weights into it.

    Args:
        ckpt_path: path of the checkpoint, read with ``torch.load`` and mapped
            onto the module-level ``device``.

    Returns:
        A ``DualAttnTransformerLM_HFHub`` when the config contains
        ``'n_heads_ra'``, otherwise a ``TransformerLM``.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    config, raw_weights = checkpoint['config'], checkpoint['model']

    # Parameter names may carry an '_orig_mod.' fragment (presumably from
    # torch.compile — confirm); remove it so keys match the un-wrapped module.
    weights = {
        name.replace('_orig_mod.', ''): tensor
        for name, tensor in raw_weights.items()
    }

    # NOTE(review): the fallback is the plain TransformerLM; whether it supports
    # push_to_hub is not visible here — confirm before uploading such checkpoints.
    model_cls = DualAttnTransformerLM_HFHub if 'n_heads_ra' in config else TransformerLM
    lm = model_cls(**config)
    lm.load_state_dict(weights)
    return lm

In [3]:
# Root directory of the fineweb experiment logs; each training run has its own
# sub-directory here containing the saved checkpoints.
base_path = '../experiments/fineweb/log'

In [4]:
# Checkpoints to upload: the step-19073 checkpoint of each selected run
# directory under `base_path`.
model_paths = [
    f'{base_path}/{run_dir}/model_19073.pt'
    for run_dir in (
        # 343M runs (currently disabled):
        # 'DAT-sa8-ra8-ns1024-sh8-nkvh4-343M_2024_07_19_13_50_14_resumed_2024_07_26_18_49_04',
        # 'DAT-sa8-ra8-nr64-ns1024-sh8-nkvh4-343M_2024_07_30_13_58_00_resumed_2024_08_14_19_34_08',
        # 'DAT-sa8-ra8-nr32-ns1024-sh8-nkvh4-343M_2024_07_30_16_55_13_resumed_2024_08_14_19_34_16',

        # 1.27B runs:
        'DAT-sa16-ra16-nr64-ns2048-sh8-nkvh8-1.27B_2024_07_28_00_48_29',
        'DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B_2024_07_31_08_52_58_resumed_2024_08_19_15_26_57',
    )
]

In [5]:
def create_model_card(model, model_name):
    """Build the Hugging Face model card for a dual-attention language model.

    The card body is rendered from ``readme_template.md`` (resolved relative to
    the current working directory) using hyperparameters read off ``model``.

    Args:
        model: model exposing ``n_layers``, ``block_size``, ``d_model``,
            ``n_heads_sa``, ``n_heads_ra`` and
            ``layers.blocks[0].dual_attn.relational_attention.n_relations``.
        model_name: name passed to the template as ``model_name``.

    Returns:
        The rendered ``ModelCard``.

    Raises:
        NotImplementedError: if the model has at most one million trainable
            parameters (no size label is defined for that range).
    """
    n_layers = model.n_layers
    block_size = model.block_size
    d_model = model.d_model
    n_heads_sa = model.n_heads_sa
    n_heads_ra = model.n_heads_ra
    # Relation dimension is read from the first block's relational attention.
    rel_dim = model.layers.blocks[0].dual_attn.relational_attention.n_relations
    training_tokens = '10B'
    tokenizer = 'GPT-2 BPE tokenizer'
    dataset = 'HuggingFaceFW/fineweb-edu'

    # Human-readable size label from the trainable parameter count.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if n_params > 1e9:
        # Keep up to two decimals so that e.g. 1.27e9 reads "1.27B" instead of
        # being rounded to "1B"; trailing zeros are dropped ("2.00" -> "2").
        msize = f'{n_params / 1e9:.2f}'.rstrip('0').rstrip('.') + 'B'
    elif n_params > 1e6:
        msize = f'{n_params / 1e6:.0f}M'
    else:
        raise NotImplementedError(
            f'no size label defined for models with {n_params} trainable parameters')

    template = dict(
        model_name=model_name,
        n_layers=n_layers,
        block_size=block_size,
        d_model=d_model,
        n_heads_sa=n_heads_sa,
        n_heads_ra=n_heads_ra,
        rel_dim=rel_dim,
        training_tokens=training_tokens,
        tokenizer=tokenizer,
        dataset=dataset,
        msize=msize,
        date=datetime.now().strftime('%B, %Y')
    )

    # The Hub metadata field is `datasets` (plural); a `dataset=` keyword would
    # only be stored as an unrecognised extra key in the card's YAML header.
    card_data = ModelCardData(
        language="en",
        license="mit",
        datasets=[dataset],
        pipeline_tag="text-generation",
        tags=["model_hub_mixin", "pytorch_model_hub_mixin"],
    )

    model_card = ModelCard.from_template(card_data, 'readme_template.md', **template)
    return model_card

In [6]:
# Interactively upload each selected checkpoint (weights + model card) to the Hub.
for model_path in model_paths:
    print('='*80)
    model = load_from_ckpt_hf(model_path)
    # The run directory is named "<model_name>_<timestamp>..."; the model name
    # is everything before the first underscore.
    model_name = model_path.split('/')[-2].split('_')[0]
    # Single target repository for BOTH the weights and the card, so they cannot
    # end up in different namespaces depending on which account is logged in.
    repo_id = f'awni00/{model_name}'
    print(f'model_name: {model_name}')
    print(f'model_path: {model_path}')

    confirm = input('Confirm? (y/n): ')
    # Tolerate surrounding whitespace and an upper-case "Y".
    if confirm.strip().lower() == 'y':
        model.push_to_hub(repo_id)
        create_model_card(model, model_name).push_to_hub(repo_id)
    else:
        print('Model not pushed.')


model_name: DAT-sa16-ra16-nr64-ns2048-sh8-nkvh8-1.27B
model_path: ../experiments/fineweb/log/DAT-sa16-ra16-nr64-ns2048-sh8-nkvh8-1.27B_2024_07_28_00_48_29/model_19073.pt


model.safetensors:   0%|          | 0.00/5.10G [00:00<?, ?B/s]

model_name: DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B
model_path: ../experiments/fineweb/log/DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B_2024_07_31_08_52_58_resumed_2024_08_19_15_26_57/model_19073.pt


model.safetensors:   0%|          | 0.00/5.11G [00:00<?, ?B/s]