In [9]:
import math
math.ceil(float(554.35))

555

In [1]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import (replace_module_by_name, get_calib_train_data,
                   get_truncate, get_unique_path, load_eval_tokenized_dataset)
from modules import SVDLinearLayer, WeightedMSELoss, HybridLoss
from tqdm import tqdm
import os
import argparse
import csv
import gc # --- CHANGE: Import garbage collector

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
RATIO = 0.6
CALIB_DATASET = "wikitext2"
model_name = "huggyllama/llama-7b"

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
SEQLEN = model.config.max_position_embeddings
BATCH_SIZE = 4
CALIB_SAMPLES = 256
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

calib_data = get_calib_train_data(name=CALIB_DATASET,
                                  tokenizer=tokenizer,
                                  seqlen=SEQLEN,
                                  batch_size=BATCH_SIZE,
                                  nsamples=CALIB_SAMPLES,
                                  seed=SEED)
print("Calibration data loaded.")

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  7.83it/s]
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


Loading cached dataset from cache/wikitext2_256_2048_42_dict.pt
Calibration data loaded.


In [3]:
# Securely get activation function from config, fallback to nn.Identity if not present
from transformers.activations import ACT2FN
activation_fn = ACT2FN[model.config.hidden_act] if hasattr(model.config, "hidden_act") and model.config.hidden_act in ACT2FN else nn.Identity()
print(f"Activation Function: {getattr(activation_fn, '__name__', activation_fn.__class__.__name__)}")

# Securely get sequence length from config, fallback to 2048 if not present
SEQLEN = getattr(model.config, "max_position_embeddings", 2048)
print(f"Sequence Length: {SEQLEN}")

Activation Function: SiLU
Sequence Length: 2048


In [None]:
GRADIENT_EPOCHS = 10
CALIB_BATCH_SIZE = 16

linear_layer_names = []
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and name != "lm_head":
        linear_layer_names.append(name)

target_layers = [
    "model.layers.3.mlp.gate_proj",
    "model.layers.3.mlp.up_proj"]

completed_layers = []

for name in tqdm(reversed(linear_layer_names), desc="Modules", total=len(linear_layer_names)):
    if name in target_layers:
        if any(name.startswith(completed_name) for completed_name in completed_layers):
            print(f"Skipping already processed layer: {name}")
            continue
        if "up_proj" in name:
            
            activations = []
            def forward_hook(module, input, output):
                activations.append(input[0].detach().cpu())

            hook = module.register_forward_hook(forward_hook)
            model.eval()
            model.half()
            model.to(DEVICE)
            with torch.no_grad():
                for batch in tqdm(calib_data, desc=f"Getting activations for {name}"):
                    input_ids = batch['input_ids'].to(DEVICE)
                    attention_mask = batch['attention_mask'].to(DEVICE)
                    model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
            hook.remove()

            # --- CHANGE: Aggressively clear VRAM after activation gathering ---
            model.to('cpu')
            gc.collect()
            torch.cuda.empty_cache()

            activations = torch.cat(activations, dim=0).float()

            module.to(DEVICE).float()

            W = module.weight.data.clone().detach().to(DEVICE)
            b = module.bias.data.clone().detach().to(DEVICE) if module.bias is not None else None

            out_features, in_features = W.shape
            low_rank = get_truncate(in_features, out_features, RATIO)
            print(f"Applying GSVD to {name} with low rank: {low_rank}")

            U, S, VT = torch.linalg.svd(W, full_matrices=False)
            s_sqrt = torch.diag(torch.sqrt(S))
            u_parameter = torch.matmul(U[:, :low_rank], s_sqrt[:low_rank, :low_rank])
            vt_parameter = torch.matmul(s_sqrt[:low_rank, :low_rank], VT[:low_rank, :])

            svd_layer = SVDLinearLayer(
                vt_parameter=vt_parameter,
                u_parameter=u_parameter,
                bias=b,
            )
            svd_layer.to(DEVICE).train()
            
            # If the layer is a mlp.gate_proj let's calibrate it with the activation function to ensure a more accurate representation
            if 'mlp.gate_proj' in name:
                activation = activation_fn
            else:
                activation = nn.Identity()

            del W, U, S, VT, s_sqrt, u_parameter, vt_parameter, b
            torch.cuda.empty_cache()
            
            importance_weights = grads[name] if GRADS_PATH is not None else None
            #Ensure no 0 importance weights are passed, if they are, use a threshold of 0.01
            if importance_weights is not None:
                importance_weights = torch.where(importance_weights < 0.01, torch.tensor(0.01, device=importance_weights.device), importance_weights)

            loss_fn = HybridLoss(weights=importance_weights,
                                alpha=0.5,  # Adjust alpha as needed
                                reduction="mean").to(DEVICE)
            
            optimizer = torch.optim.AdamW(svd_layer.parameters(), lr=1e-4, weight_decay=1e-3)

            print(f"Training GSVD layer for {name}...")
            activation_dataset = torch.utils.data.TensorDataset(activations)
            activation_loader = torch.utils.data.DataLoader(
                activation_dataset,
                batch_size=CALIB_BATCH_SIZE,
                shuffle=True
            )
            
            loss_log = []
            for epoch in range(GRADIENT_EPOCHS):
                epoch_loss_log = []
                pbar = tqdm(activation_loader, desc=f"Epoch {epoch+1}/{GRADIENT_EPOCHS}", leave=False)

                for batch_tuple in pbar:
                    batch = batch_tuple[0].to(DEVICE)
                    optimizer.zero_grad()
                    output = activation(svd_layer(batch))
                    # The target needs to be computed on the GPU with the original module
                    with torch.no_grad():
                        target = activation(module(batch))
                    loss = loss_fn(output, target)
                    loss.backward()
                    optimizer.step()
                    epoch_loss_log.append(loss.item())
                    pbar.set_postfix({'loss': loss.item()})
                loss_log.append(sum(epoch_loss_log) / len(epoch_loss_log))

            print(f"Finished training GSVD layer for {name}. Loss: {loss_log[-1]}")
            
            svd_layer.to('cpu').eval().half()
            replace_module_by_name(model, name, svd_layer)
            
            # --- CHANGE: Clean up and save checkpoint ---
            del svd_layer, activations, loss_fn, optimizer, module
            gc.collect()
            torch.cuda.empty_cache()


Modules:   0%|          | 0/224 [00:00<?, ?it/s]

model.layers.3.mlp.up_proj


Getting activations for model.layers.3.mlp.up_proj:   6%|▋         | 4/64 [00:04<01:00,  1.02s/it]
Modules:  88%|████████▊ | 197/224 [00:05<00:00, 33.65it/s]


KeyboardInterrupt: 