In [None]:
from transformers import AutoModelForCausalLM

from utils import *
from training import *
from validation import *

In [None]:
# --- Configuration: fill these in before running -----------------------------
model_type = "llama3" # mistral, gemma, stablelm

hf_account_name = "" # huggingface.co username
save_name = "" # name to save the model as
model_name = "" # hf_repo/model_name
model_name_to_beat = model_name # set to the same unless comparing against a different model

# Fail fast with a readable message instead of the opaque error that
# from_pretrained("") raises when the placeholder above is left empty.
if not model_name:
    raise ValueError("Set model_name (e.g. 'hf_repo/model_name') in the config cell before loading the model.")

# --- Load the model to be trained and its tokenizer ---------------------------
# load_local_config() returns extra keyword arguments that are forwarded to
# from_pretrained (contents are project-specific — TODO document expected keys).
params = load_local_config()
model = AutoModelForCausalLM.from_pretrained(model_name, **params, cache_dir="Models")
# Overwrite the recorded model name with the name this run will be saved as.
model.config.name_or_path = save_name
model = model.to("cuda")

# The tokenizer is selected by model family (model_type), not loaded from model_name.
tokenizer = get_tokenizer(model_type)

# Apply the project's weight normalization to the trainable model.
# NOTE: this rebinds `model` from itself, so re-running this cell applies
# norm_model_weights a second time to the already-processed model.
model = norm_model_weights(model)

# Reference ("model to beat") loaded with the same kwargs as `model`. It is
# only a comparison target, so freeze all of its parameters in one call.
# NOTE(review): unlike `model`, base_model is not moved to "cuda" here —
# presumably Trainer handles its placement; confirm.
base_model = AutoModelForCausalLM.from_pretrained(model_name_to_beat, **params, cache_dir="Models")
base_model.requires_grad_(False)

trainer = Trainer(model, tokenizer, base_model)

In [None]:
# Sanity-check parameters: print the validate_parameters report (with
# print_vals=True) for the trainable model first, then for the reference model.
for checked_model in (model, base_model):
    print(validate_parameters(checked_model, print_vals=True))

In [None]:
# Run training. All hyperparameters are forwarded verbatim to Trainer.train;
# see that method for the exact meaning of each argument.
# Checkpoints use the save_name configured at the top of the notebook (this
# call previously hard-coded "test", silently ignoring that setting).
if not save_name:
    raise ValueError("Set save_name in the config cell before training; it is passed to trainer.train.")

trainer.train(
    # optimization
    acc_batch_size=512, opt="adamw", lr=1e-5, lr_schedule="constant", weight_decay=0.0, betas=(0.9, 0.99),
    # warmup
    warmup_steps=0, warmup_end_offset=0,
    # gradient clipping and per-sample handling
    grad_clip_norm=1.0, ignore_overshot_samples=True, bad_sample_mult=1.0, ignore_sample_loss_below=0.0, precalc_batch_mult=2.25,
    remerging=False, remerge_ratio=0.75,
    # loss configuration
    base_relative_loss=False, loss_eps=0.02, overshoot_buffer=-0.01, eval_eps=0.01,
    # periodic evaluation and per-metric revert thresholds
    eval_steps=512, revert=True, eval_revert_if={"loss": 0.004, "head_to_head": -12.5, "eps0_head_to_head": -22.5},
    # saving
    save_name=save_name, do_save=True, cortex_steps=5, max_steps=None,
    # memory / device
    gradient_checkpointing=False, excessive_cache_clearing=False, device="cuda",
)

In [None]:
# Compare the trained model against the reference model on 768 samples,
# without deduplication; the result is shown as this cell's output.
validate_improvement(
    model,
    base_model,
    samples=768,
    tokenizer_name=model_type,
    dedup=False,
)

In [None]:
# Upload the tokenizer and the trained weights (safetensors) to a private
# Hugging Face Hub repo named "<hf_account_name>/<save_name>".
# Validate first: with the empty config placeholders the repo id would be a
# malformed "/" or "user/", which the Hub rejects with a less helpful error.
if not hf_account_name or not save_name:
    raise ValueError("Set hf_account_name and save_name in the config cell before uploading.")

upload_name = f"{hf_account_name}/{save_name}"
tokenizer.push_to_hub(repo_id=upload_name, private=True)
commit_info = model.push_to_hub(repo_id=upload_name, safe_serialization=True, private=True)