In [3]:
import torch
import transformers

from src.utils import model_utils
from src import data

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
!export CUDA_VISIBLE_DEVICES=2,3,4,5,6,7
!export CUDA_LAUNCH_BLOCKING=1

In [5]:
model = model_utils.get_llama("meta-llama/Llama-2-7b-hf", device_map="auto",
                              dtype=torch.float32)

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.90s/it]


Model loaded. LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)

In [6]:
dataset = data.get_loaders("pajama", seqlen = model.config.max_position_embeddings, 
                           model="meta-llama/Llama-2-7b-hf")
                           

Loading Red Pajama: 100%|██████████| 128/128 [00:14<00:00,  9.03it/s]


In [7]:
trainset = dataset[:100]
valset = dataset[100:]

#convert both to torch dataset
class simple_dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"input_ids": self.data[idx][0][...,:-2].squeeze(0), "labels": self.data[idx][0][...,1:][...,:-1].squeeze(0)}

trainset = simple_dataset(trainset)
valset = simple_dataset(valset)

In [8]:
#train the model on the dataset with transformers trainer

trainer = transformers.Trainer(
    model=model,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        fp16=True,
        logging_steps=1,
        save_steps=100,
        output_dir="./output",
        num_train_epochs=1,
        save_total_limit=3,
        evaluation_strategy="no",
        eval_steps=100,
        num_train_epochs=100,
        # per_device_train_batch_size=32,
        logging_steps=1,
        load_best_model_at_end=False,
        remove_unused_columns=False,
        #add a tqdm progress bar
        report_to="none",
    ),
    train_dataset=trainset,
    eval_dataset=valset,
)

trainer.train()





`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,11.451
2,6.9622
3,6.0308
4,5.9486
5,5.3535


KeyboardInterrupt: 



In [24]:
model(valset[0]["input_ids"][...,:-2].cuda(),
      labels=valset[0]["input_ids"][...,1:].cuda()[...,:-1]  # shift labels
      ).loss

tensor(13.3660, device='cuda:0', grad_fn=<ToCopyBackward0>)

In [6]:
import glob
import tqdm
import torch

In [7]:
paths = glob.glob("/data/lliu/huffman/models/meta-llama/*/hessianDiags/seed_0/pajama/128/*/*.pt")
print(len(paths))

1848


In [10]:
for p in tqdm.tqdm(paths):
    hessianDiag = torch.load(p)
    if "hessianDiag" in hessianDiag:
        continue
    torch.save({"hessianDiag": hessianDiag["hessian"]}, p)

100%|██████████| 1848/1848 [00:01<00:00, 979.71it/s] 


In [9]:
torch.load(p)

{'hessianDiag': tensor([0.0067, 0.0076, 0.0071,  ..., 0.0070, 0.0077, 0.0074], device='cuda:1',
        dtype=torch.float16)}