In [None]:
%load_ext autoreload
%autoreload 2

import torch
import transformers
import yaml
from src.utils import model_utils
from src.utils import quantized_model
from src.model import llama
from transformers import LlamaForCausalLM as OrigLlama
import os
from src import data
import tqdm
import torch

In [2]:
# !export CUDA_VISIBLE_DEVICES=0,1
# !export CUDA_LAUNCH_BLOCKING=1
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# torch.distributed.init_process_group(backend='nccl')

In [3]:
#config
ft_n_train = 16
ft_n_val = 8
ft_dataset = "pajama"
base_model = "meta-llama/Llama-2-7b-hf"
seqlen = 4096
batch_size = 1
per_device_train_batch_size = 2
use_embedding = False


In [4]:
orig_model = OrigLlama.from_pretrained(base_model,
                                       device_map="auto", torch_dtype=torch.float32)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
overall_data:list[torch.FloatTensor] = data.get_loaders(ft_dataset, nsamples = ft_n_train+ft_n_val
                                  , model = base_model, train_test = "train",
                                  seqlen=seqlen)

overall_data = torch.stack([_[0][0] for _ in overall_data])

Loading Red Pajama: 100%|██████████| 24/24 [00:01<00:00, 18.14it/s]


In [6]:
@torch.no_grad()
def calculate_logits(model: llama.LlamaForCausalLM, devset, batch_size):
    logits = []
    for i in tqdm.tqdm(range(len(devset) // batch_size), desc = "Calculating logits"):
        logits.append(
            model(devset[i * batch_size:(i + 1) *
                         batch_size].cuda())['logits'].cpu())
    logits = torch.concat(logits, dim=0)
    return logits

In [7]:
overall_out = calculate_logits(orig_model,overall_data, batch_size)

Calculating logits:   0%|          | 0/24 [00:00<?, ?it/s]

Calculating logits: 100%|██████████| 24/24 [01:31<00:00,  3.81s/it]


In [8]:
overall_out.shape

torch.Size([24, 4096, 32000])

In [9]:
overall_out = overall_out[:, :-1].contiguous().softmax(dim=-1).float()

In [10]:
overall_out

tensor([[[2.2354e-10, 5.7025e-09, 1.6306e-05,  ..., 1.0361e-08,
          1.3468e-09, 3.6893e-09],
         [1.4929e-10, 1.2274e-09, 1.1974e-05,  ..., 2.2617e-08,
          7.3949e-09, 2.8877e-09],
         [2.0296e-10, 6.5319e-10, 8.4720e-06,  ..., 1.0241e-08,
          4.6695e-09, 1.4944e-09],
         ...,
         [1.3458e-11, 4.6223e-14, 3.6947e-07,  ..., 1.2013e-10,
          1.8254e-09, 1.7071e-09],
         [4.1465e-12, 1.3264e-13, 8.3513e-07,  ..., 1.8017e-10,
          3.2149e-10, 4.0146e-10],
         [1.6634e-10, 4.3322e-12, 1.4700e-07,  ..., 1.1152e-09,
          9.0015e-10, 4.6064e-08]],

        [[2.3169e-09, 2.6811e-09, 1.5956e-04,  ..., 2.8206e-08,
          1.2517e-08, 6.9140e-09],
         [2.9108e-08, 1.1318e-08, 3.8498e-04,  ..., 2.5980e-07,
          5.6099e-07, 1.5601e-07],
         [6.9376e-09, 4.8868e-08, 2.5782e-05,  ..., 2.3716e-07,
          6.9378e-08, 7.1069e-08],
         ...,
         [4.4877e-13, 1.4639e-13, 1.6864e-06,  ..., 5.0638e-11,
          8.654

In [11]:

del orig_model

In [12]:
from src.utils import utils



utils.clean()

In [13]:
model = llama.LlamaForCausalLM.from_pretrained("/data/lliu/huffman/models/meta-llama/Llama-2-7b-hf/compressed_hf/run_38",
                                               device_map="auto",
                                                  torch_dtype=torch.float32,
                                                    low_cpu_mem_usage=True)

# model = model_utils.get_llama("meta-llama/Llama-2-7b-hf",
#                                device_map="auto",
#                                 dtype=torch.float32)
                          
    

device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float32
device None dtype torch.float32
codeb

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
with torch.no_grad():
    temp_out = model(overall_data[[0]].cuda())['logits'].contiguous().softmax(dim=-1).float()

In [22]:
temp_out

tensor([[[4.6737e-07, 2.1708e-05, 6.8499e-04,  ..., 4.6186e-07,
          2.2432e-07, 1.3042e-06],
         [2.0906e-10, 6.5542e-10, 9.4043e-06,  ..., 2.3275e-08,
          5.9380e-09, 2.9740e-09],
         [1.7926e-10, 3.2220e-10, 7.7861e-06,  ..., 1.0558e-08,
          3.6094e-09, 1.0416e-09],
         ...,
         [1.5693e-11, 3.6049e-13, 2.3453e-06,  ..., 5.4415e-10,
          1.1082e-09, 5.7102e-10],
         [5.7886e-10, 4.3732e-11, 1.7115e-06,  ..., 3.1174e-09,
          9.5443e-09, 9.1943e-08],
         [9.4791e-10, 2.3938e-11, 1.8201e-07,  ..., 2.9170e-09,
          5.9034e-08, 2.2966e-08]]], device='cuda:0')

In [21]:
overall_out[[0]]

tensor([[[2.2354e-10, 5.7025e-09, 1.6306e-05,  ..., 1.0361e-08,
          1.3468e-09, 3.6893e-09],
         [1.4929e-10, 1.2274e-09, 1.1974e-05,  ..., 2.2617e-08,
          7.3949e-09, 2.8877e-09],
         [2.0296e-10, 6.5319e-10, 8.4720e-06,  ..., 1.0241e-08,
          4.6695e-09, 1.4944e-09],
         ...,
         [1.3458e-11, 4.6223e-14, 3.6947e-07,  ..., 1.2013e-10,
          1.8254e-09, 1.7071e-09],
         [4.1465e-12, 1.3264e-13, 8.3513e-07,  ..., 1.8017e-10,
          3.2149e-10, 4.0146e-10],
         [1.6634e-10, 4.3322e-12, 1.4700e-07,  ..., 1.1152e-09,
          9.0015e-10, 4.6064e-08]]])

In [25]:
kld = torch.nn.KLDivLoss(reduction="batchmean")
kld(temp_out[:,:-1], overall_out[[0]].cuda())

tensor(-10581.0996, device='cuda:0')

In [14]:
overall_embeddings = model.model.embed_tokens(overall_data.cuda()).detach()
                           

In [15]:
overall_embeddings.requires_grad

False

In [16]:
overall_out.shape

torch.Size([24, 4095, 32000])

In [17]:
from transformers.modeling_attn_mask_utils import \
    _prepare_4d_causal_attention_mask
    
position_ids = torch.arange(seqlen, dtype=torch.int32)[None, :] + \
    torch.zeros(per_device_train_batch_size, seqlen, dtype=torch.int32)
attention_mask = _prepare_4d_causal_attention_mask(
    None, (per_device_train_batch_size, seqlen), overall_embeddings[:per_device_train_batch_size], 0)   

In [18]:
from torch.utils.data import Dataset
from typing import Tuple

class SimpleDataset(Dataset):

    def __init__(self, inputs_embeds, soft_labels,attention_mask, position_ids):
        self.inputs_embeds = inputs_embeds
        self.soft_labels = soft_labels
        self.attention_mask = attention_mask[0]
        self.position_ids = position_ids[0]
    

    def __len__(self):
        return len(self.inputs_embeds)

    def __getitem__(self, idx):
        return {
            'inputs_embeds': self.inputs_embeds[idx],
            'labels': self.soft_labels[idx],
            'attention_mask': self.attention_mask,
            'position_ids': self.position_ids
        }
    
def make_datasets(X:torch.FloatTensor, Y:torch.FloatTensor, n_val:int,
                    attention_mask, position_ids) -> Tuple[Dataset, Dataset]:
    

    #make the indices
    idxs = torch.randperm(len(X))
    train_idxs = idxs[:-n_val]

    train_ds = SimpleDataset(X[train_idxs], Y[train_idxs], attention_mask, position_ids)
    valid_ds = SimpleDataset(X[idxs[-n_val:]], Y[idxs[-n_val:]], attention_mask, position_ids)
    return train_ds, valid_ds

traindataset, validdataset = make_datasets(overall_embeddings, overall_out, ft_n_val, attention_mask, position_ids)

In [19]:
len(traindataset), len(validdataset)

(16, 8)

In [22]:


#custom kld loss
def custom_kld_loss(outputs, labels, num_items_in_batch):
    
    logits = outputs['logits'][:,:-1]
    print(logits.shape, labels.shape, num_items_in_batch)
    
    #do kld on the last dim
    loss = torch.nn.KLDivLoss(reduction='sum')(logits, labels)
    if num_items_in_batch == 0 or num_items_in_batch is None:
        l = loss/logits.numel()
        print(l)
        return l
    else:
        print(loss / num_items_in_batch)
        return loss / num_items_in_batch
    
    

In [23]:
#train the model on the dataset with transformers trainer

trainer = transformers.Trainer(
    model=model,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        fp16=True,
        logging_steps=1,
        output_dir="./output",
        num_train_epochs=10,
        save_total_limit=3,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        # eval_steps=100,
        # per_device_train_batch_size=32,
        load_best_model_at_end=True,
        #add a tqdm progress bar
        #set the lr to 1e-5
        learning_rate=1e-4,
        warmup_steps=0,
        dataloader_pin_memory=False,
        #set the logging dir to ./logs
        logging_dir="./logs",
        #log to wandb
        report_to="wandb",
        run_name="llama-2-7b-hf",
    ),
    train_dataset=traindataset,
    eval_dataset=validdataset,
    compute_loss_func=custom_kld_loss,
)

trainer.train()



Could not estimate the number of tokens of the input, floating-point operations will not be computed


torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)


Epoch,Training Loss,Validation Loss
1,-0.0006,-0.000625
2,-0.0006,-0.000654
3,-0.0007,-0.000679
4,-0.0007,-0.000701
5,-0.0007,-0.000719
6,-0.0008,-0.000734


torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(



torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(



torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(



torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0006, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(



torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(



torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0008, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0007, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(-0.0008, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([2, 4095, 32000]) torch.Size([2, 4095, 32000]) tensor(262080000, device='cuda:0')
tensor(

KeyboardInterrupt: 

In [25]:
model.save_pretrained("./output")   

In [19]:
!nvidia-smi

Fri Mar 21 21:09:03 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.01             Driver Version: 535.216.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               Off | 00000000:05:00.0 Off |                  Off |
| 47%   70C    P2             284W / 300W |  48538MiB / 49140MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               Off | 00000000:06:00.0 Off |  

In [24]:
model(valset[0]["input_ids"][...,:-2].cuda(),
      labels=valset[0]["input_ids"][...,1:].cuda()[...,:-1]  # shift labels
      ).loss

tensor(13.3660, device='cuda:0', grad_fn=<ToCopyBackward0>)

In [6]:
import glob
import tqdm
import torch

In [7]:
paths = glob.glob("/data/lliu/huffman/models/meta-llama/*/hessianDiags/seed_0/pajama/128/*/*.pt")
print(len(paths))

1848


In [None]:
for p in tqdm.tqdm(paths):
    hessianDiag = torch.load(p)
    if "hessianDiag" in hessianDiag:
        continue
    torch.save({"hessianDiag": hessianDiag["hessian"]}, p)

100%|██████████| 1848/1848 [00:01<00:00, 979.71it/s] 


: 

In [9]:
torch.load(p)

{'hessianDiag': tensor([0.0067, 0.0076, 0.0071,  ..., 0.0070, 0.0077, 0.0074], device='cuda:1',
        dtype=torch.float16)}

In [None]:
class A:
    
    def __init__(self):
        self.a = 1
        self.b = 2
        self.c = 3
        
    def fn1(self):
        print("fn1_A")
        
    @classmethod
    def fn1_static(cls):
        c = cls()
        c.fn1()
        return c
    
class B(A):
    
    def __init__(self):
        super().__init__()
        print("initializing B")
        self.d = 4
        
    def fn1(self):
        print("fn1_B")
        
B.fn1_static()
    


initializing B
fn1_B


<__main__.B at 0x7a053f3f42f0>

: 