In [None]:
import os
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

import numpy as np
import math
from torch.nn import CrossEntropyLoss
import csv
import gc
import matplotlib.pyplot as plt
import pickle

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# For recording logs
# ===========================================================
class LogCallback(TrainerCallback):
    """Collect one metrics row per evaluation and export the rows as CSV.

    Each call to ``on_evaluate`` appends a dict to ``epoch_logs``. The
    throughput/runtime fields of a row start as ``None``; ``on_train_end``
    fills them in on the last row from the trainer's final log entry, and
    ``record_pruning_energy`` fills in ``avg_kept_energy`` on the last row.
    """

    def __init__(self):
        # One dict per evaluation, in the order the evaluations ran.
        self.epoch_logs = []
        # Throughput/runtime summary copied from the trainer's log history.
        self.final_train_metrics = {}

    def count_trainable_params(self, model):
        """Return the total element count of parameters with ``requires_grad``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """Append a row built from ``metrics`` and the trainer ``state``.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``epoch``, ``global_step`` and
                ``log_history`` are read.
            control: Trainer control object (unused).
            metrics: Evaluation metrics; ``"eval_loss"`` is read if present.
            **kwargs: ``model`` is read (if given) to count trainable params.
        """
        epoch = state.epoch if state.epoch is not None else 0
        eval_loss = metrics.get("eval_loss")
        perplexity = round(np.exp(eval_loss), 4) if eval_loss is not None else None

        # Most recent training loss: the newest history entry carrying "loss".
        train_loss = None
        for log in reversed(state.log_history):
            if "loss" in log:
                train_loss = log["loss"]
                break

        model = kwargs.get("model", None)
        trainable_params = self.count_trainable_params(model) if model is not None else None

        self.epoch_logs.append({
            "epoch": epoch,
            "total_steps": state.global_step,
            "train_loss": train_loss,
            "eval_loss": eval_loss,
            "perplexity": perplexity,
            "trainable_params": trainable_params,
            "avg_kept_energy": None,
            "samples_per_second": None,
            "steps_per_second": None,
            "total_flos": None,
            "train_runtime": None,
        })

    def on_train_end(self, args, state, control, **kwargs):
        """Copy the final throughput/runtime summary into the last row."""
        required = ["train_samples_per_second", "train_steps_per_second", "total_flos", "train_runtime"]
        for log in reversed(state.log_history):
            if all(k in log for k in required):
                self.final_train_metrics = {
                    "samples_per_second": log["train_samples_per_second"],
                    "steps_per_second": log["train_steps_per_second"],
                    "total_flos": log["total_flos"],
                    "train_runtime": log["train_runtime"]
                }
                break

        if self.epoch_logs and self.final_train_metrics:
            self.epoch_logs[-1].update(self.final_train_metrics)

    def record_pruning_energy(self, constraints_callback):
        """Store ``constraints_callback.last_avg_kept_energy`` in the last row.

        Prints a notice and leaves the rows untouched when the attribute is
        missing/``None`` or when no row has been recorded yet.
        """
        avg_energy = getattr(constraints_callback, "last_avg_kept_energy", None)
        if avg_energy is None:
            print("[LogCallback] No avg_kept_energy found in constraints_callback.")
            return

        if not self.epoch_logs:
            # Previously this case was silent, which hid a lost measurement.
            print("[LogCallback] No epoch log available to attach avg_kept_energy to.")
            return

        self.epoch_logs[-1]["avg_kept_energy"] = avg_energy
        print(f"[LogCallback] Recorded avg_kept_energy = {avg_energy:.4f} to last epoch log.")

    def save_to_csv(self, filepath="epoch_metrics.csv"):
        """Append all recorded rows to ``filepath`` as CSV.

        The parent directory is created when missing, so a long run is not
        lost to a ``FileNotFoundError`` at export time. The header row is
        written when the file is new or empty; columns follow the keys of the
        first recorded row.
        """
        if not self.epoch_logs:
            print("No logs to save.")
            return

        parent_dir = os.path.dirname(filepath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # An existing but empty file still needs its header row.
        needs_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0
        fieldnames = list(self.epoch_logs[0].keys())

        with open(filepath, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            if needs_header:
                writer.writeheader()
            writer.writerows(self.epoch_logs)

In [None]:
# Text generation helper
# ===========================================================
def generate_text(model, tokenizer, prompt, max_new_tokens=20):
    """Generate a continuation of ``prompt`` and print the decoded text.

    The model is switched to eval mode, the prompt is tokenized into tensors
    on ``model.device``, and generation runs with gradients disabled. Only the
    first returned sequence is decoded (special tokens skipped) and printed;
    nothing is returned.
    """
    model.eval()
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(max_new_tokens=max_new_tokens, **encoded)
    first_sequence = generated[0]
    print(tokenizer.decode(first_sequence, skip_special_tokens=True))


In [None]:
# Load the dataset and tokenizer
# ===========================================================
# WikiText-2 (raw variant): the "train" and "validation" splits are used below.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
train_data, val_data = dataset["train"], dataset["validation"]

# Local directory holding the pre-trained checkpoint and its tokenizer files.
model_dir = './orig_models/llama2-7b-hf'
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
# Padding reuses the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token

# Map token
def tokenize_fn(examples):
    """Tokenize the ``"text"`` field of a batch, truncating to 512 tokens."""
    return tokenizer(examples["text"], truncation=True, max_length=512)

# Tokenize both splits in batches; the raw "text" column is dropped afterwards.
tokenized_train, tokenized_val = (
    split.map(tokenize_fn, batched=True, remove_columns=["text"])
    for split in (train_data, val_data)
)
print("Train / Validation mapping: Done\n")

# Flip to True to work on a small prefix of each split while debugging.
debug_usage = False
if debug_usage:
    print("Debugging mode: Only small part of data is loaded.")
    tokenized_train = tokenized_train.select(range(320))  # first 320 samples
    tokenized_val = tokenized_val.select(range(64))       # first 64 samples

# Drop input_ids with length ≤ 1
def filter_short(example):
    """Keep only examples whose ``input_ids`` hold at least two tokens."""
    token_count = len(example["input_ids"])
    return token_count >= 2
print(f"Before dropping: Train ({len(tokenized_train)}) / Val ({len(tokenized_val)})")
# Remove sequences that filter_short rejects from both splits.
tokenized_train, tokenized_val = (
    tokenized_train.filter(filter_short),
    tokenized_val.filter(filter_short),
)
print(f"After dropping: Train ({len(tokenized_train)}) / Val ({len(tokenized_val)})")
# Show the first tokenized training example.
print(tokenized_train[0])

In [None]:
# Customise the UDV layer for Q, K, V and MLP 
# ===========================================================
class UDV_LoRA(nn.Module):
    """Frozen linear map plus a trainable low-rank ``U diag(D) Vt`` update.

    ``forward(x)`` computes ``x @ W0 + scaling * ((x @ U) * D) @ Vt``.

    Args:
        in_dim: Size of the last input dimension.
        out_dim: Size of the last output dimension.
        udv_rank: Rank of the low-rank update.
        alpha_scaling: ``scaling`` is ``alpha_scaling / udv_rank``; a value
            of ``0`` selects ``scaling = 1.0``.
        dtype: dtype of ``W0``, ``U``, ``D`` and ``Vt``. Defaults to
            ``torch.float16``, the previously hard-coded value.

    ``W0`` is a persistent, non-trainable buffer of shape
    ``(in_dim, out_dim)``. It is allocated uninitialised and must be filled
    through ``set_weight`` before the layer is used.
    """

    def __init__(self, in_dim, out_dim, udv_rank, alpha_scaling, dtype=torch.float16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.udv_rank = udv_rank
        self.alpha_scaling = alpha_scaling

        # Buffer rather than Parameter: saved with the state dict, never trained.
        self.register_buffer("W0", torch.empty(self.in_dim, self.out_dim, dtype=dtype), persistent=True)

        self.U = nn.Parameter(torch.empty(self.in_dim, self.udv_rank, dtype=dtype), requires_grad=True)
        self.D = nn.Parameter(torch.empty(1, self.udv_rank, dtype=dtype), requires_grad=True)
        self.Vt = nn.Parameter(torch.empty(self.udv_rank, self.out_dim, dtype=dtype), requires_grad=True)
        with torch.no_grad():
            self.init_parameters()

        self.scaling = self.alpha_scaling / self.udv_rank if self.alpha_scaling != 0 else 1.0

    def init_parameters(self):
        """Initialise U ~ N(0, 1e-2), D = 1 and Vt = 0.

        With ``Vt`` at zero the low-rank branch contributes nothing at start.
        """
        nn.init.normal_(self.U, std=1e-2)
        nn.init.constant_(self.D, 1.0)
        nn.init.zeros_(self.Vt)

    def set_weight(self, w):
        """Copy ``w`` into ``W0``, casting to the buffer's dtype.

        Raises:
            ValueError: If ``w`` does not have shape ``(in_dim, out_dim)``.
        """
        if w.shape != self.W0.shape:
            raise ValueError(f"Shape mismatch: Expected {self.W0.shape}, got {w.shape}")
        self.W0.copy_(w.to(self.W0.dtype))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``[..., in_dim]``."""
        orig_out = torch.matmul(x, self.W0)
        low_rank = torch.matmul(x, self.U)            # [..., rank]
        # Flattening D to [rank] broadcasts over any number of leading dims;
        # the former view(1, 1, -1) added spurious dims to 1-D/2-D inputs.
        low_rank = low_rank * self.D.reshape(-1)      # [..., rank]
        low_rank = torch.matmul(low_rank, self.Vt)    # [..., out_dim]

        return orig_out + low_rank * self.scaling


In [None]:
# (Batch) Replace linear layers by UDV_LoRA
# ===========================================================
def udv_replacement(model, target_keys, udv_rank, alpha_scaling):
    """Swap matching ``nn.Linear`` submodules of ``model`` for ``UDV_LoRA``.

    A submodule is replaced when it is an ``nn.Linear`` and its qualified
    name ends with one of ``target_keys``. The transposed weight of the old
    layer becomes the frozen ``W0`` of the new one.

    Args:
        model: Module whose submodules are replaced in place.
        target_keys: Iterable of name suffixes, e.g. ``"self_attn.q_proj"``.
        udv_rank: Rank passed to every ``UDV_LoRA``.
        alpha_scaling: Scaling argument passed to every ``UDV_LoRA``.

    Returns:
        The same ``model`` object, modified in place.

    Notes:
        * Linear layers that carry a bias are skipped with a message, because
          the replacement is built from the weight only and would silently
          drop the bias.
        * The new layer is moved to the device of the weight it replaces, so
          the result does not depend on a later ``model.to(...)`` call.
    """
    # Snapshot the matches first: the module tree is mutated in the loop
    # below, and named_modules() is a generator over that same tree.
    candidates = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(name.endswith(key) for key in target_keys)
    ]

    replace_counter = 0
    for name, old_linear in candidates:
        if old_linear.bias is not None:
            print(f"[udv_replacement] Skipped {name}: the layer has a bias that the replacement would drop.")
            continue

        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model

        new_layer = UDV_LoRA(old_linear.in_features, old_linear.out_features, udv_rank, alpha_scaling)
        new_layer.to(old_linear.weight.device)
        with torch.no_grad():
            # nn.Linear stores (out, in); the replacement multiplies x @ W0, so W0 is (in, out).
            new_layer.set_weight(old_linear.weight.detach().T)
        setattr(parent, child_name, new_layer)

        # Sanity check that the frozen weight survived the copy/cast.
        expected_w0 = old_linear.weight.detach().T.to(new_layer.W0.dtype)
        assert torch.allclose(expected_w0, new_layer.W0.detach(), atol=1e-5)
        replace_counter += 1

    print(f"Replaced {replace_counter} layers in total.")
    return model


In [None]:
# Constraints
# ===========================================================
class UDV_UVbound(nn.Module):
    """Project a matrix onto the ball ``||M||_F <= uv_norm_limit`` in place.

    The previous version compared the *squared* norm to the limit and always
    rescaled to unit norm, which is only a valid projection for a limit of 1.
    The norm itself is now compared to the limit and an oversized matrix is
    rescaled to exactly that limit. Results are unchanged for a limit of 1.
    """

    def __init__(self, uv_norm_limit):
        super().__init__()
        self.uv_norm_limit = uv_norm_limit

    def forward(self, udv_uv):
        """Rescale ``udv_uv`` in place if its Frobenius norm exceeds the limit.

        Returns the same tensor object that was passed in.
        """
        # The norm is taken in float32 so half-precision inputs lose less accuracy.
        fro_norm = torch.linalg.matrix_norm(udv_uv.float(), ord='fro')
        if fro_norm > self.uv_norm_limit:
            # Dividing by norm / limit leaves the matrix with norm == limit.
            udv_uv.div_((fro_norm / self.uv_norm_limit).to(udv_uv.dtype))
        return udv_uv

class UDV_Dbound(nn.Module):
    """Overwrite entries below ``d_lowerbound`` with ``d_boundto`` in place."""

    def __init__(self, d_lowerbound, d_boundto):
        super().__init__()
        self.d_lowerbound = d_lowerbound
        self.d_boundto = d_boundto

    def forward(self, udv_d):
        """Apply the bound to ``udv_d`` and return the same tensor object."""
        below_bound = udv_d < self.d_lowerbound
        udv_d.masked_fill_(below_bound, self.d_boundto)
        return udv_d

def check_udv_constraints(udv_param_groups, uv_norm_limit, d_lowerbound):
    """Report which UDV layers violate the norm and lower-bound constraints.

    Args:
        udv_param_groups: Iterable of ``(layer_name, {"U", "D", "Vt"})`` pairs.
        uv_norm_limit: Upper limit for the Frobenius norms of ``U`` and ``Vt``.
            A tolerance factor of 1.001 is applied to it.
        d_lowerbound: Smallest value any entry of ``D`` may take.

    Returns:
        A list with one dict per violating layer (keys ``layer``, ``u_norm``,
        ``v_norm``, ``d_min``, ``u_ok``, ``v_ok``, ``d_ok``). The list is
        empty when every layer passes or when no layers were given. The same
        information is printed.
    """
    failed_layers = []
    scale_factor = 1.001  # Tolerance on the norm limit to avoid spurious numerical failures
    checked_layers = 0

    for layer_name, udv in udv_param_groups:
        checked_layers += 1
        U = udv['U'].detach().float()
        D = udv['D'].detach().float()
        Vt = udv['Vt'].detach().float()

        u_norm = torch.linalg.matrix_norm(U, ord='fro').item()
        v_norm = torch.linalg.matrix_norm(Vt, ord='fro').item()
        d_min = D.min().item()

        u_ok = u_norm <= (uv_norm_limit * scale_factor)
        v_ok = v_norm <= (uv_norm_limit * scale_factor)
        d_ok = d_min >= d_lowerbound

        if not (u_ok and v_ok and d_ok):
            failed_layers.append({
                "layer": layer_name,
                "u_norm": u_norm,
                "v_norm": v_norm,
                "d_min": d_min,
                "u_ok": u_ok,
                "v_ok": v_ok,
                "d_ok": d_ok,
            })

    if checked_layers == 0:
        # Avoid reporting success when nothing was actually checked.
        print("[Constraint Check] No UDV layers were provided; nothing was checked.")
    elif failed_layers:
        print(f"[Constraint Check] Constraint violation detected (scale factor [*{scale_factor}]):")
        for item in failed_layers:
            print(f"  Layer: {item['layer']}")
            print(f"    U Fro norm = {item['u_norm']:.4f} | OK: {item['u_ok']}")
            print(f"    V Fro norm = {item['v_norm']:.4f} | OK: {item['v_ok']}")
            print(f"    D min val  = {item['d_min']:.4f} | OK: {item['d_ok']}")
    else:
        print(f"[Constraint Check] All UDV layers satisfy the constraints(scale factor [*{scale_factor}]).")

    return failed_layers


# ===========================================================        
class Constraints_CallBack(TrainerCallback):
    """Trainer callback that constrains, records and prunes UDV layers.

    During training it projects ``U``/``Vt`` (norm bound) and ``D`` (lower
    bound) after optimizer updates and records singular values of
    ``U * D``. After training, ``post_pruning`` / ``recover_pruning``
    truncate the layers by SVD energy and restore them again.

    Args:
        uv_norm_limit: Limit handed to ``UDV_UVbound``.
        d_lowerbound: Lower bound handed to ``UDV_Dbound``.
        d_boundto: Replacement value handed to ``UDV_Dbound``.
        pruning_energy_threshold: Default energy fraction used by
            ``auto_svd_pruning`` when it is called with ``None``.
    """

    def __init__(self, uv_norm_limit, d_lowerbound, d_boundto, pruning_energy_threshold):
        self.uv_bound = UDV_UVbound(uv_norm_limit)
        self.d_bound = UDV_Dbound(d_lowerbound, d_boundto)
        self.udv_param_groups = []  # List of (module_name, {'U': ..., 'D': ..., 'Vt': ...})
        self.original_params = {}   # {layer_name: {'U': tensor, 'D': tensor, 'Vt': tensor}}
        self.svd_record = {}        # {layer_name: [{"step": int, "svd": ndarray}, ...]}
        self.pruning_energy_threshold = pruning_energy_threshold
        self.initialized = False
        # Set up front so readers never hit a missing attribute.
        self.pruning_record = {}
        self.last_avg_kept_energy = None
        self.tokenizer = None

    def _find_udv_groups(self, model):
        """Return ``(name, {'U', 'D', 'Vt'})`` for every module owning all three Parameters."""
        udv_groups = []
        for module_name, module in model.named_modules():
            if all(hasattr(module, name) for name in ['U', 'D', 'Vt']):
                U, D, Vt = module.U, module.D, module.Vt
                if all(isinstance(p, torch.nn.Parameter) for p in [U, D, Vt]):
                    udv_groups.append((module_name, {'U': U, 'D': D, 'Vt': Vt}))
        return udv_groups

    def _ensure_udv_groups(self, model):
        """Discover the UDV layers of ``model`` once, on first use."""
        if not self.initialized:
            self.udv_param_groups = self._find_udv_groups(model)
            self.initialized = True

    def _refresh_udv_groups(self, model):
        """Re-read the live Parameters after they were replaced on the modules."""
        self.udv_param_groups = self._find_udv_groups(model)
        self.initialized = True

    def on_step_end(self, args, state, control, **kwargs):
        """Project the parameters after every optimizer update.

        ``on_step_end`` is called once per optimizer update (accumulation
        sub-steps trigger ``on_substep_end`` instead) and ``global_step``
        counts those updates. The former
        ``global_step % gradient_accumulation_steps`` gate therefore skipped
        the projection on most updates. Constraints are now applied on every
        call. The comparatively expensive SVD recording keeps its original
        cadence.
        """
        model = kwargs["model"]
        self._ensure_udv_groups(model)

        self.apply_constraints()
        if state.global_step % args.gradient_accumulation_steps == 0:
            self.record_svd(step=state.global_step)

        return control

    def apply_constraints(self):
        """Apply the norm bound to U and Vt and the lower bound to D, in place."""
        with torch.no_grad():
            for name, param_dict in self.udv_param_groups:
                U = param_dict['U']
                D = param_dict['D']
                Vt = param_dict['Vt']

                U.data.copy_(self.uv_bound(U.data))
                Vt.data.copy_(self.uv_bound(Vt.data))
                D.data.copy_(self.d_bound(D.data))

    def record_svd(self, step):
        """Append the singular values of ``U * D`` for every layer to ``svd_record``.

        A layer whose SVD raises is reported and skipped (best effort).
        """
        with torch.no_grad():
            for layer_name, udv in self.udv_param_groups:
                try:
                    uv_matrix = torch.mul(udv['U'], udv['D']).float()
                    _, S_svd, _ = torch.linalg.svd(uv_matrix, full_matrices=False)

                    self.svd_record.setdefault(layer_name, []).append({
                        "step": step,
                        "svd": S_svd.cpu().numpy(),
                    })
                except Exception as e:
                    print(f"[Warning] SVD failed on layer {layer_name}: {e}")

    def plot_svd(self, layers=None, steps=None, normal_layers=None, show=True, save_path=None, model=None):
        """Plot recorded singular-value spectra on a log scale.

        Args:
            layers: UDV layer names to plot; all recorded layers when falsy.
            steps: Optional collection of steps to include; all when ``None``.
            normal_layers: Dotted names of modules in ``model`` whose current
                ``weight`` spectrum is added as a dashed line.
            show: Whether to call ``plt.show()``.
            save_path: Directory that receives ``svd_spectra.png`` when given.
            model: Model used to resolve ``normal_layers``.
        """
        plt.figure(figsize=(6, 4))
        layers_to_plot = layers or list(self.svd_record.keys())

        # ---- Plot UDV layers
        for layer in layers_to_plot:
            records = self.svd_record.get(layer, [])
            if not records:
                print(f"[plot_svd] No SVD records found for UDV layer: {layer}")
                continue

            for rec in records:
                step = rec["step"]
                if steps is not None and step not in steps:
                    continue
                label = f"{layer} (step {step})"
                plt.plot(rec["svd"], label=label)

        # ---- Plot normal layers (final weights only)
        if normal_layers and model is not None:
            for layer_name in normal_layers:
                try:
                    module = model
                    for attr in layer_name.split('.'):
                        if attr.isdigit():
                            module = module[int(attr)]
                        else:
                            module = getattr(module, attr)
                    weight = module.weight.float()
                    _, S_svd, _ = torch.linalg.svd(weight, full_matrices=False)
                    label = f"{layer_name} (final)"
                    plt.plot(S_svd.cpu().numpy(), label=label, linestyle='--')
                except Exception as e:
                    print(f"[plot_svd] Failed to get SVD of {layer_name}: {e}")

        plt.title("SVD Spectra - UDV and Normal Layers")
        plt.xlabel("Singular Value Index")
        plt.ylabel("Singular Value (log scale)")
        plt.yscale("log")
        plt.legend(fontsize="small", loc="best")
        plt.grid(True)

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            plt.savefig(os.path.join(save_path, "svd_spectra.png"))

        if show:
            plt.show()

        plt.close()

    def replace_para(self, optimizer, layer, name, new_param, transfer_state=False):
        """Replace ``layer.<name>`` with a new Parameter and patch the optimizer.

        Args:
            optimizer: Optimizer whose param groups/state reference the old
                Parameter. ``None`` is accepted (only the module is updated).
            layer: Module owning the parameter.
            name: Attribute name of the parameter on ``layer``.
            new_param: Tensor wrapped into the new ``nn.Parameter``.
            transfer_state: Move the old optimizer state to the new Parameter
                instead of discarding it.
        """
        old_param = getattr(layer, name)
        new_param = nn.Parameter(new_param)
        setattr(layer, name, new_param)

        if optimizer is None:
            return

        for group in optimizer.param_groups:
            for i, p in enumerate(group['params']):
                if p is old_param:
                    group['params'][i] = new_param
        if old_param in optimizer.state:
            if transfer_state:
                optimizer.state[new_param] = optimizer.state.pop(old_param)
            else:
                del optimizer.state[old_param]

    def auto_svd_pruning(self, model, optimizer, energy_threshold):
        """Truncate every UDV layer to the smallest rank reaching the energy threshold.

        With ``U * D = P diag(S) Q``, the layer is rewritten as
        ``U' = P[:, :k]``, ``D' = S[:k]``, ``Vt' = Q[:k] @ Vt``, where ``k``
        is the smallest rank whose cumulative squared singular values reach
        ``energy_threshold`` times the total.

        Args:
            model: Model containing the UDV layers.
            optimizer: Optimizer to patch; may be ``None``.
            energy_threshold: Fraction of energy to keep. ``None`` selects
                ``self.pruning_energy_threshold``. ``0`` is honoured as a real
                threshold and keeps rank 1.

        Returns:
            ``(svd_record, avg_kept_energy)``: the singular values per layer,
            and the mean kept-energy fraction, or ``None`` if no layer was
            processed.

        Notes:
            * The pre-pruning parameters of a layer are saved once in
              ``original_params`` and kept until ``recover_pruning`` runs.
            * ``udv_param_groups`` is refreshed afterwards so it references
              the live (pruned) Parameters. NOTE(review): the pruned ``U``
              generally has a Frobenius norm above a limit of 1, so resuming
              training would let ``apply_constraints`` rescale it - confirm
              before training after pruning.
        """
        if energy_threshold is None:
            energy_threshold = self.pruning_energy_threshold
        # Allows pruning a model that never went through on_step_end.
        self._ensure_udv_groups(model)

        kept_energies, retained_layers = [], 0
        svd_record = {}
        modules_by_name = dict(model.named_modules())

        with torch.no_grad():
            for layer_name, udv in self.udv_param_groups:
                layer = modules_by_name[layer_name]
                try:
                    uv_matrix = torch.mul(udv['U'], udv['D']).float()
                    U_svd, S_svd, Vh_svd = torch.linalg.svd(uv_matrix, full_matrices=False)
                except Exception as e:
                    print(f"[Prune] Skipped {layer_name}: {e}")
                    continue

                svd_record[layer_name] = S_svd.cpu().numpy()
                energy = torch.cumsum(S_svd ** 2, dim=0)
                total_energy = energy[-1]
                keep_rank = torch.searchsorted(energy, energy_threshold * total_energy).item() + 1
                # A threshold above 1 (or rounding) can push the index past the last entry.
                keep_rank = min(keep_rank, S_svd.numel())
                if total_energy.item() > 0:
                    kept_energy = (energy[keep_rank - 1] / total_energy).item()
                else:
                    # All singular values are zero: nothing can be lost.
                    kept_energy = 1.0
                kept_energies.append(kept_energy)

                if keep_rank == S_svd.numel():
                    retained_layers += 1
                    continue

                if layer_name not in self.original_params:
                    # copy=True keeps the snapshot independent even on CPU.
                    self.original_params[layer_name] = {
                        "U": udv['U'].detach().to("cpu", copy=True),
                        "D": udv['D'].detach().to("cpu", copy=True),
                        "Vt": udv['Vt'].detach().to("cpu", copy=True),
                    }

                # Build all three tensors from the old parameters before replacing any.
                new_U = U_svd[:, :keep_rank].to(udv['U'].dtype)
                new_D = S_svd[:keep_rank].unsqueeze(0).to(udv['D'].dtype)
                new_Vt = torch.matmul(Vh_svd[:keep_rank, :], udv['Vt'].float()).to(udv['Vt'].dtype)

                self.replace_para(optimizer, layer, 'U', new_U)
                self.replace_para(optimizer, layer, 'D', new_D)
                self.replace_para(optimizer, layer, 'Vt', new_Vt)

        self._refresh_udv_groups(model)

        pruned_layers = len(kept_energies)
        avg_kept_energy = sum(kept_energies) / pruned_layers if pruned_layers > 0 else None
        if avg_kept_energy is None:
            print("[UDV Prune] No UDV layers were processed.")
        else:
            print(f"[UDV Prune] Avg Kept Energy: {avg_kept_energy * 100:.4f}% over "
                  f"{pruned_layers} layers. {retained_layers} layers retained full rank.")

        return svd_record, avg_kept_energy

    def recover_pruning(self, model, optimizer):
        """Restore every pruned layer from ``original_params`` and clear the snapshots."""
        if not self.original_params:
            print("[Recover] No saved parameters found.")
            return

        print(f"[Recover] Restoring {len(self.original_params)} layers to pre-pruned state...")
        modules_by_name = dict(model.named_modules())
        for layer_name, params in self.original_params.items():
            layer = modules_by_name[layer_name]
            self.replace_para(optimizer, layer, "U", params["U"].to(layer.U.device), transfer_state=False)
            self.replace_para(optimizer, layer, "D", params["D"].to(layer.D.device), transfer_state=False)
            self.replace_para(optimizer, layer, "Vt", params["Vt"].to(layer.Vt.device), transfer_state=False)

        # The snapshots are consumed; the next pruning pass takes fresh ones.
        self.original_params.clear()
        self._refresh_udv_groups(model)

        print("[Recover] All pruned layers restored.")

    def set_tokenizer(self, tokenizer):
        """Store ``tokenizer`` on the callback."""
        self.tokenizer = tokenizer

    def on_train_end(self, args, state, control, **kwargs):
        """Print a reminder that pruning is a separate, manual step."""
        print("Pruning is not applied here. Save fine-tuned model first then test pruning.")

    def post_pruning(self, model, optimizer, energy_threshold):
        """Run ``auto_svd_pruning`` and keep its results on the callback.

        Stores the per-layer singular values in ``pruning_record`` and the
        mean kept energy in ``last_avg_kept_energy``.
        """
        print(f"[Post Pruning] Starting manual pruning with energy threshold = {energy_threshold}...")
        self.pruning_record, self.last_avg_kept_energy = self.auto_svd_pruning(model, optimizer, energy_threshold)
        if self.last_avg_kept_energy is None:
            print("[Post Pruning] Completed. No layers were processed.")
        else:
            print(f"[Post Pruning] Completed. Avg kept energy: {self.last_avg_kept_energy:.4f}")




In [None]:
# Run configuration: rank, learning rate, number of epochs and output locations.
udv_lora_r = 64
udv_lora_learning_rate = 1e-3
udv_lora_num_train_epochs = 30
udv_lora_output_path = "./udv_lora_r_{}_SGD_{}".format(udv_lora_r, udv_lora_learning_rate)
csv_path = f"{udv_lora_output_path}/epoch_metrics.csv"

In [None]:
# Re-load pre-trained model and replace linear layers by UDV_LoRA
# ===========================================================
udv_lora_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Freeze every pre-trained weight before the UDV layers are inserted.
for param in udv_lora_model.parameters():
    param.requires_grad_(False)

# Name suffixes of the linear layers to replace; commented entries are disabled.
TARGET_KEYS = [
    "self_attn.q_proj",
    # "self_attn.k_proj",
    "self_attn.v_proj",
    # "self_attn.o_proj",
    # "mlp.up_proj",
    # "mlp.gate_proj",
    # "mlp.down_proj",
    # "lm_head"
]

# Replace layers and send to device
udv_lora_model = udv_replacement(
    model=udv_lora_model,
    target_keys=TARGET_KEYS,
    udv_rank=udv_lora_r,
    alpha_scaling=0,
)
udv_lora_model.to(device)

# ===========================================================


# Causal-LM collation (mlm disabled).
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

udv_training_args = TrainingArguments(
    output_dir=udv_lora_output_path,
    # Batching and numerics
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    fp16=False,
    logging_nan_inf_filter=True,
    seed=42,
    data_seed=42,
    # Optimizer and schedule
    optim='sgd',
    learning_rate=udv_lora_learning_rate,
    weight_decay=0,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    max_grad_norm=0,
    lr_scheduler_type="cosine",
    # Training length and evaluation cadence
    num_train_epochs=udv_lora_num_train_epochs,
    eval_strategy="epoch",
    # Checkpointing
    save_strategy="no",
    save_total_limit=1,
    load_best_model_at_end=False,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Logging
    logging_strategy="epoch",
    report_to=["none"],
    label_names=["labels"],
)

# ===========================================================
log_callback = LogCallback()
constraint_callback = Constraints_CallBack(
    uv_norm_limit=1,
    d_lowerbound=0,
    d_boundto=0,
    pruning_energy_threshold=1,
)
constraint_callback.set_tokenizer(tokenizer)

udv_trainer = Trainer(
    model=udv_lora_model,
    args=udv_training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    callbacks=[log_callback, constraint_callback],
)

# One evaluation pass before any training provides the baseline metrics.
udv_lora_output = udv_trainer.evaluate()


In [None]:
# Evaluation: Fine-tuning with UDV-LoRA
# ===========================================================
# Fine-tune; evaluation runs during training as configured in the arguments.
udv_lora_output = udv_trainer.train()

# Verify the trained UDV parameters against the limits the callback enforced.
check_udv_constraints(
    constraint_callback.udv_param_groups,
    uv_norm_limit=constraint_callback.uv_bound.uv_norm_limit,
    d_lowerbound=constraint_callback.d_bound.d_lowerbound,
)

log_callback.save_to_csv(csv_path)


In [None]:
# Energy thresholds 1.00, 0.95, ..., 0.05. The previous round(x, 1) on this
# 0.05 grid collapsed neighbouring values into duplicates, so evaluations were
# repeated at identical thresholds and the 0.05 resolution was lost.
energy_list = [round(1.0 - 0.05 * step, 2) for step in range(20)]

for energy in energy_list:
    # Start every threshold from the un-pruned parameters.
    constraint_callback.recover_pruning(model=udv_trainer.model,
                                        optimizer=udv_trainer.optimizer)

    constraint_callback.post_pruning(model=udv_trainer.model,
                                     optimizer=udv_trainer.optimizer,
                                     energy_threshold=energy)

    # A fresh logger per threshold, so each pass appends exactly one CSV row.
    after_pruning_log = LogCallback()

    udv_lora_output = udv_trainer.evaluate()
    after_pruning_log.on_evaluate(
        args=udv_trainer.args,
        state=udv_trainer.state,
        control=udv_trainer.control,
        metrics=udv_lora_output,
        model=udv_trainer.model
    )

    after_pruning_log.record_pruning_energy(constraint_callback)
    after_pruning_log.save_to_csv(csv_path)


In [None]:
# Persist the singular values recorded during training next to the other run outputs.
with open(f"{udv_lora_output_path}/svd_dict.pkl", "wb") as f:
    pickle.dump(constraint_callback.svd_record, f)


In [None]:
# Completion marker: reaching this cell means every cell above ran through.
print("all done")