In [1]:
%load_ext autoreload
%autoreload 2

import torch
import transformers
import yaml
from src.utils import model_utils
from src.utils import quantized_model
from src.model import llama
from transformers import LlamaForCausalLM as OrigLlama
import os
from src import data
from src.utils import utils
import tqdm
import torch

In [2]:
# !export CUDA_VISIBLE_DEVICES=0,1
# !export CUDA_LAUNCH_BLOCKING=1
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# torch.distributed.init_process_group(backend='nccl')

In [3]:
#config
ft_n_train = 256
ft_n_val = 32
ft_dataset = "pajama"
base_model = "meta-llama/Llama-2-7b-hf"
seqlen = 4096
batch_size = 1
per_device_train_batch_size = 4
use_embedding = False
seed = 0


utils.seed(seed)

In [4]:
overall_data:list[torch.FloatTensor] = data.get_loaders(ft_dataset, nsamples = ft_n_train+ft_n_val
                                  , model = base_model, train_test = "train",
                                  seqlen=seqlen)

overall_data = torch.stack([_[0][0] for _ in overall_data])

Loading Red Pajama: 100%|██████████| 288/288 [00:23<00:00, 12.09it/s]


In [5]:
model = llama.LlamaForCausalLM.from_pretrained("/data/lliu/huffman/models/meta-llama/Llama-2-7b-hf/compressed_hf/run_38",
                                               device_map="auto",
                                                  torch_dtype=torch.float32,
                                                    low_cpu_mem_usage=True)

# model = model_utils.get_llama("meta-llama/Llama-2-7b-hf",
#                                device_map="auto",
#                                 dtype=torch.float32)
                          
    

device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codeb

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
model.model.layers[0].self_attn.k_proj.normalizer.zeros[0]

Parameter containing:
tensor([], device='cuda:0', requires_grad=True)

In [7]:
from torch.utils.data import Dataset
from typing import Tuple

class SimpleDataset(Dataset):

    def __init__(self, input_ids):
        self.input_ids = input_ids
    

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx], "labels": self.input_ids[idx]
        }
    
def make_datasets(X, n_val:int) -> Tuple[Dataset, Dataset]:
    

    #make the indices
    idxs = torch.randperm(len(X))
    train_idxs = idxs[:-n_val]

    train_ds = SimpleDataset(X[train_idxs])
    valid_ds = SimpleDataset(X[idxs[-n_val:]])
    return train_ds, valid_ds

traindataset, validdataset = make_datasets(overall_data, ft_n_val)

In [8]:
len(traindataset), len(validdataset)

(256, 32)

In [9]:
traindataset[0]

{'input_ids': tensor([  746,  1906, 14582,  ...,  3153, 13433,   470]),
 'labels': tensor([  746,  1906, 14582,  ...,  3153, 13433,   470])}

In [10]:


#custom kld loss
def custom_kld_loss(outputs, labels, num_items_in_batch):
    
    logits = outputs['logits'][:,:-1]
    print(logits.shape, labels.shape, num_items_in_batch)
    
    #do kld on the last dim
    loss = torch.nn.KLDivLoss(reduction='sum')(logits, labels)
    if num_items_in_batch == 0 or num_items_in_batch is None:
        l = loss/logits.numel()
        print(l)
        return l
    else:
        print(loss / num_items_in_batch)
        return loss / num_items_in_batch
    
    

In [13]:
#train the model on the dataset with transformers trainer

trainer = transformers.Trainer(
    model=model,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        fp16=True,
        logging_steps=1,
        output_dir="./output",
        num_train_epochs=10,
        save_total_limit=3,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        # eval_steps=100,
        # per_device_train_batch_size=32,
        load_best_model_at_end=True,
        #add a tqdm progress bar
        #set the lr to 1e-5
        learning_rate=1e-5,
        warmup_steps=100,
        lr_scheduler_type="cosine_with_restarts",
        lr_scheduler_kwargs={"num_cycles": 5},
        dataloader_pin_memory=False,
        #set the logging dir to ./logs
        logging_dir="./logs",
        #log to wandb
        report_to="wandb",
        run_name="llama-2-7b-hf",
        eval_on_start = True,
    ),
    train_dataset=traindataset,
    eval_dataset=validdataset,
)
#get the pre training performance
#train the model
trainer.train()



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mm6481[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss
0,No log,1.967803
1,2.021700,1.904445
2,1.985700,1.890278
3,1.574100,1.889801
4,1.790800,1.891467
5,1.628000,1.893058


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


KeyboardInterrupt: 

In [None]:
model.save_pretrained("./output")   

In [14]:
#get the path to the best model
best_model = trainer.state.best_model_checkpoint

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x72202e117a10>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 722020ff37d0, raw_cell="#get the path to the best model
best_model = train.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B164.67.204.101/data/lliu/huffman/train_hard_labels.ipynb#X14sdnNjb2RlLXJlbW90ZQ%3D%3D>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x72202e117a10>> (for post_run_cell), with arguments args (<ExecutionResult object at 722020ff3710, execution_count=14 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 722020ff37d0, raw_cell="#get the path to the best model
best_model = train.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B164.67.204.101/data/lliu/huffman/train_hard_labels.ipynb#X14sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
model(valset[0]["input_ids"][...,:-2].cuda(),
      labels=valset[0]["input_ids"][...,1:].cuda()[...,:-1]  # shift labels
      ).loss

tensor(13.3660, device='cuda:0', grad_fn=<ToCopyBackward0>)

In [None]:
import glob
import tqdm
import torch

In [None]:
paths = glob.glob("/data/lliu/huffman/models/meta-llama/*/hessianDiags/seed_0/pajama/128/*/*.pt")
print(len(paths))

1848


In [None]:
for p in tqdm.tqdm(paths):
    hessianDiag = torch.load(p)
    if "hessianDiag" in hessianDiag:
        continue
    torch.save({"hessianDiag": hessianDiag["hessian"]}, p)

100%|██████████| 1848/1848 [00:01<00:00, 979.71it/s] 


: 

In [None]:
torch.load(p)

{'hessianDiag': tensor([0.0067, 0.0076, 0.0071,  ..., 0.0070, 0.0077, 0.0074], device='cuda:1',
        dtype=torch.float16)}

In [None]:
class A:
    
    def __init__(self):
        self.a = 1
        self.b = 2
        self.c = 3
        
    def fn1(self):
        print("fn1_A")
        
    @classmethod
    def fn1_static(cls):
        c = cls()
        c.fn1()
        return c
    
class B(A):
    
    def __init__(self):
        super().__init__()
        print("initializing B")
        self.d = 4
        
    def fn1(self):
        print("fn1_B")
        
B.fn1_static()
    


initializing B
fn1_B


<__main__.B at 0x7a053f3f42f0>

: 