In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import math
from einops import rearrange, einsum
from torch.nn import functional as F

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
pre_act = np.load('../Llama_7B/pre_act_layer31_mlp_down.npy')

pre_act = pre_act.reshape(-1, pre_act.shape[-1]).T

print(pre_act.shape)
weights = np.load('../Llama_7B/weight_layer31_mlp_down.npy')
print(weights.shape)

(11008, 18432)
(4096, 11008)


In [3]:
#Turn the data into torch tensors
pre_act = torch.from_numpy(pre_act).float().to(DEVICE)
weights = torch.from_numpy(weights).float().to(DEVICE)

In [4]:
ratio = 0.6 # 60% of the weights
out_features, in_features = weights.shape
        
if in_features == out_features:
    truncate = math.ceil(ratio * in_features / 2)
else:
    truncate = math.ceil((ratio * in_features * out_features) / (in_features + out_features))
    
# truncate = math.ceil(min(in_features, out_features) * ratio)

print(f"Truncating {truncate} weights")

Truncating 1792 weights


In [12]:
ratio = 0.6 # 60% of the weights
out_features, in_features = weights.shape
        
if in_features == out_features:
    truncate = int(ratio * in_features / 2)
else:
    truncate = int((ratio * in_features * out_features) / (in_features + out_features))
    
# truncate = math.ceil(min(in_features, out_features) * ratio)

print(f"Truncating {truncate} weights")

Truncating 1228 weights


In [13]:
from utils import get_truncate

ratio = 0.6
out_features, in_features = weights.shape

truncate = get_truncate(in_features, out_features, ratio)
print(f"Truncating {truncate} weights")

Truncating 1229 weights


In [14]:
ratio = 0.6 # 60% of the weights
out_features, in_features = weights.shape
        

truncate = int((ratio * in_features * out_features) / (in_features + out_features))
    
# truncate = math.ceil(min(in_features, out_features) * ratio)

print(f"Truncating {truncate} weights")

Truncating 1228 weights


In [5]:
act_real = torch.matmul(weights, pre_act)
print(act_real.shape)

torch.Size([4096, 18432])


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelScaling(nn.Module):
    """
    Batch Normalization without mean and beta.
    This is a custom implementation that normalizes the input by dividing by the standard deviation
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, num_features)
        Returns:
            Tensor of shape (batch_size, seq_len, num_features)
        """
        
        if self.training:
            var = x.var(dim=(0, 1), unbiased=False)# Calculate variance over batch and sequence dimensions
            self.running_var.data.copy_ = self.momentum * var + (1 - self.momentum) * self.running_var # Update running variance
        else:
            var = self.running_var # Use running variance during inference

        x_norm = x / torch.sqrt(var.view(1, 1, -1) + self.eps) # Normalize the input by dividing by the standard deviation
        return self.gamma.view(1, 1, -1) * x_norm  # Scale the normalized input by gamma
        

class SVDLinearLayer(nn.Module):
    def __init__(self, weights, truncate, bias=None, data=None):
        super(SVDLinearLayer, self).__init__()
        
        torch.cuda.empty_cache()
        
        device = weights.device
        
        out_features, in_features = weights.shape
        
        if data.dim() == 2: # Add batch dimension if missing.
            data = data.unsqueeze(0).to(device)
            
        self.normalization1 = ChannelScaling(weights.shape[1]).to(device)
        data_var = data.var(dim=(0, 1), unbiased=False).to(device)
        self.normalization1.running_var.data.copy_(data_var)
        
        # Compute normalization factors for the weights
        diag_norm1 = torch.diag(self.normalization1.gamma / torch.sqrt(self.normalization1.running_var + self.normalization1.eps)).to(device)
        
        weights = torch.matmul(weights, torch.inverse(diag_norm1))
        
        # #Perform SVD on the weights
        # U, S, Vt = torch.linalg.svd(weights, full_matrices=False)
        # U = U[:, :truncate]
        # S = S[:truncate]
        # Vt = Vt[:truncate, :]
        
        U, S, V = torch.svd_lowrank(weights, q=truncate, niter=1)
        Vt = V.t()
        
        diag_s = torch.diag(torch.sqrt(S))
        
        vt_parameter = torch.matmul(diag_s, Vt)
        
        self.vt_linear = nn.Linear(in_features, truncate, bias=False)
        self.vt_linear.weight.data.copy_(vt_parameter)
        
        self.normalization2 = ChannelScaling(U.shape[1]).to(device)
        
        self.normalization1.eval()
        data_var = F.linear(self.normalization1(data), vt_parameter, bias=None).var(dim=(0, 1), unbiased=False).to(device)
        self.normalization1.train()
        self.normalization2.running_var.data.copy_(data_var)
        diag_norm2 = torch.diag(self.normalization2.gamma / torch.sqrt(self.normalization2.running_var + self.normalization2.eps)).to(device)
        
        u_parameter = torch.matmul(U, torch.matmul(diag_s, torch.inverse(diag_norm2)))
            
        self.bias = bias
        
        self.u_linear = nn.Linear(truncate, out_features, bias=True if bias is not None else False)
        if bias is not None:
            self.u_linear.bias.data.copy_(bias)
        self.u_linear.weight.data.copy_(u_parameter)
        
        del weights, U, S, Vt, diag_s, diag_norm1, diag_norm2, data_var, data, device, u_parameter, vt_parameter
        
        torch.cuda.empty_cache()
        
    def forward(self, x):
        
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, in_features) or (seq_len, in_features)
        Returns:
            Tensor of shape (batch_size, seq_len, out_features)
        """
        
        if x.dim() == 2: # Add batch dimension if missing.
            x = x.unsqueeze(0)
            
        x = self.normalization1(x)
        
        x = self.vt_linear(x)
        
        x = self.normalization2(x) 
        
        x = self.u_linear(x)
        
        return x
        
    
    def reconstruct_weights(self):
        """
        Reconstruct the effective weight matrix, taking into account the normalization layers.
        """
        device = self.vt_linear.weight.device
        
        # Incorporate normalization factors from normalization1.
        if isinstance(self.normalization1, ChannelScaling):
            norm1 = self.normalization1
            diag_norm1 = torch.diag(norm1.gamma / torch.sqrt(norm1.running_var + norm1.eps))
        else:
            diag_norm1 = torch.eye(self.vt_linear.weight.shape[1], device=device)
            
        # Incorporate normalization factors from normalization2.
        if isinstance(self.normalization2, ChannelScaling):
            norm2 = self.normalization2
            diag_norm2 = torch.diag(norm2.gamma / torch.sqrt(norm2.running_var + norm2.eps))
        else:
            diag_norm2 = torch.eye(self.u_linear.weight.shape[1], device=device)
        
        # Reconstruct the weight matrix.
        return self.u_linear.weight @ diag_norm2 @ self.vt_linear.weight @ diag_norm1

In [7]:
new_module = SVDLinearLayer(weights, truncate, None, pre_act.T).to(DEVICE)

In [8]:
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

ITERS = 500
losses = []
new_module.train()

optimizer = torch.optim.AdamW(new_module.parameters(), lr=0.0001)

pbar = tqdm(range(ITERS))
for i in pbar:
    optimizer.zero_grad()
    
    aprox_act = new_module(pre_act.T).squeeze().T
    
    loss = torch.norm(aprox_act - act_real, p='fro')
    losses.append(loss.item())
    
    loss.backward()
    
    optimizer.step()
    
    pbar.set_postfix({'loss': loss.item()})

100%|██████████| 500/500 [01:24<00:00,  5.93it/s, loss=1.77e+3]


In [9]:
with torch.no_grad():
    new_module.eval()
    aprox_act = new_module(pre_act.T).squeeze().T
    print("loss: ", torch.norm(aprox_act - act_real, p='fro').detach().cpu().numpy())

loss:  1872.4783


In [10]:
print(new_module)

SVDLinearLayer(
  (normalization1): ChannelScaling()
  (vt_linear): Linear(in_features=11008, out_features=1792, bias=False)
  (normalization2): ChannelScaling()
  (u_linear): Linear(in_features=1792, out_features=4096, bias=False)
)


In [20]:
new_module = SVDLinearLayer(weights, truncate, None, pre_act.T).to(DEVICE)

In [21]:
from tqdm import tqdm

# Training vt_linear with normalization1
new_module.train()
for param in new_module.parameters():
    param.requires_grad = False
for param in new_module.vt_linear.parameters():
    param.requires_grad = True
for param in new_module.normalization1.parameters():
    param.requires_grad = True

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, new_module.parameters()), lr=0.0001)
losses = []

pbar = tqdm(range(ITERS), desc="Training vt_linear with normalization1")
for i in pbar:
    optimizer.zero_grad()
    
    aprox_act = new_module(pre_act.T).squeeze().T
    loss = torch.norm(aprox_act - act_real, p='fro')
    losses.append(loss.item())
    
    loss.backward()
    optimizer.step()
    pbar.set_postfix({'loss': loss.item()})

# Training u_linear with normalization2
new_module.train()
for param in new_module.parameters():
    param.requires_grad = False
for param in new_module.u_linear.parameters():
    param.requires_grad = True
for param in new_module.normalization2.parameters():
    param.requires_grad = True

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, new_module.parameters()), lr=0.0001)

pbar = tqdm(range(ITERS), desc="Training u_linear with normalization2")
for i in pbar:
    optimizer.zero_grad()
    
    aprox_act = new_module(pre_act.T).squeeze().T
    loss = torch.norm(aprox_act - act_real, p='fro')
    losses.append(loss.item())
    
    loss.backward()
    optimizer.step()
    pbar.set_postfix({'loss': loss.item()})

# Training everything together
new_module.train()
for param in new_module.parameters():
    param.requires_grad = True

optimizer = torch.optim.Adam(new_module.parameters(), lr=0.0001)

pbar = tqdm(range(ITERS), desc="Training everything")
for i in pbar:
    optimizer.zero_grad()
    
    aprox_act = new_module(pre_act.T).squeeze().T
    loss = torch.norm(aprox_act - act_real, p='fro')
    losses.append(loss.item())
    
    loss.backward()
    optimizer.step()
    pbar.set_postfix({'loss': loss.item()})

Training vt_linear with normalization1: 100%|██████████| 500/500 [01:17<00:00,  6.47it/s, loss=2.82e+3]
Training u_linear with normalization2: 100%|██████████| 500/500 [00:46<00:00, 10.67it/s, loss=2.04e+3]
Training everything: 100%|██████████| 500/500 [01:24<00:00,  5.95it/s, loss=1.76e+3]


In [23]:
with torch.no_grad():
    new_module.train()
    aprox_act = new_module(pre_act.T).squeeze().T
    print("loss: ", torch.norm(aprox_act - act_real, p='fro').detach().cpu().numpy())

loss:  1759.162


In [43]:
def get_ratios(final_ratio, matrix_iters):
    ratios = []
    for i in range(1, matrix_iters + 1):
        r = 1 - i * (1 - final_ratio) / matrix_iters
        ratios.append(r)
    return ratios

def get_truncate(in_features, out_features, ratio):
    if in_features == out_features:
        return math.ceil(ratio * in_features / 2)
    else:
        return math.ceil((ratio * in_features * out_features) / (in_features + out_features))

GRADIENT_ITERS = 500
MATRIX_ITERS = 2
FINAL_RATIO = 0.6

ratios = get_ratios(FINAL_RATIO, MATRIX_ITERS)

new_weights = weights.clone()
new_weights = new_weights.to(DEVICE)

for ratio in ratios:
    in_features, out_features = new_weights.shape
    truncate = get_truncate(in_features, out_features, ratio)
    
    print(f"Truncating {truncate} weights for ratio {ratio}")
    
    new_module = SVDLinearLayer(new_weights, truncate, None, pre_act.T).to(DEVICE)
    new_module.train()
    optimizer = torch.optim.AdamW(new_module.parameters(), lr=0.0001)
    
    pbar = tqdm(range(GRADIENT_ITERS))
    for i in pbar:
        optimizer.zero_grad()

        aprox_act = new_module(pre_act.T).squeeze().T

        loss = torch.norm(aprox_act - act_real, p='fro')
        losses.append(loss.item())

        loss.backward()

        optimizer.step()

        pbar.set_postfix({'loss': loss.item()})
        
    new_module.eval()
    with torch.no_grad():
        new_weights = new_module.reconstruct_weights()
        new_weights = new_weights.to(DEVICE)
    

Truncating 2389 weights for ratio 0.8


100%|██████████| 500/500 [00:50<00:00,  9.94it/s, loss=1.25e+3]


Truncating 1792 weights for ratio 0.6


100%|██████████| 500/500 [00:38<00:00, 12.93it/s, loss=1.77e+3]


In [44]:
with torch.no_grad():
    new_module.eval()
    aprox_act = new_module(pre_act.T).squeeze().T
    print("loss: ", torch.norm(aprox_act - act_real, p='fro').detach().cpu().numpy())

loss:  1823.6628


In [1]:
from svdmodels import SVDModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from utils import load_wikitext

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b",torch_dtype=torch.float16)
SEQ_LEN = model.config.max_position_embeddings
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = SVDModel.load_model(model, ratio=0.8, model_path="results/llama-7b/gsvd_llama-7b_r0.8_g500_c256_m1.pt")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.half()
model.to(DEVICE)
model.eval()

with torch.no_grad():
    BATCH_SIZE = 8
    
    loader = load_wikitext(tokenizer,
                           seq_len=SEQ_LEN,
                            batch_size=BATCH_SIZE)
    
    nlls = []
    for batch in tqdm(loader, desc="Evaluating", total=len(loader)):
        batch = batch.to(DEVICE)
        logits = model(input_ids=batch, use_cache=False).logits
        if torch.isfinite(logits).all():
            shited_logits = logits[:, :-1, :].contiguous()
            shifted_labels = batch[:, 1:].contiguous()
            loss_fnc = torch.nn.CrossEntropyLoss(reduction='none')
            loss = loss_fnc(shited_logits.view(-1, logits.size(-1)), shifted_labels.view(-1))
            nlls.append(loss.cpu())
        else:
            print("Non-finite logits detected, skipping batch.")
            continue
        
    mean_loss = torch.cat(nlls).mean()
    ppl = torch.exp(mean_loss).item()
    if ppl > 1000:
        ppl = int(ppl)
        
    print(f"Perplexity: {ppl}")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 10.14it/s]
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Replacing modules: 100%|██████████| 423/423 [00:20<00:00, 20.49it/s]


Skipping lm_head


Token indices sequence length is longer than the specified maximum sequence length for this model (341469 > 2048). Running this sequence through the model will result in indexing errors
Evaluating: 100%|██████████| 21/21 [00:42<00:00,  2.00s/it]

Perplexity: 8.5546875



