In [1]:
!pip install transformers datasets tqdm



In [2]:
from transformers import AutoModel, AutoTokenizer, BertForSequenceClassification, BertTokenizer
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from datasets import load_dataset

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


In [3]:
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2", num_labels=2)
model.to("cuda")

for name, module in model.named_modules():
    print(name)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



bert
bert.embeddings
bert.embeddings.word_embeddings
bert.embeddings.position_embeddings
bert.embeddings.token_type_embeddings
bert.embeddings.LayerNorm
bert.embeddings.dropout
bert.encoder
bert.encoder.layer
bert.encoder.layer.0
bert.encoder.layer.0.attention
bert.encoder.layer.0.attention.self
bert.encoder.layer.0.attention.self.query
bert.encoder.layer.0.attention.self.key
bert.encoder.layer.0.attention.self.value
bert.encoder.layer.0.attention.self.dropout
bert.encoder.layer.0.attention.output
bert.encoder.layer.0.attention.output.dense
bert.encoder.layer.0.attention.output.LayerNorm
bert.encoder.layer.0.attention.output.dropout
bert.encoder.layer.0.intermediate
bert.encoder.layer.0.intermediate.dense
bert.encoder.layer.0.intermediate.intermediate_act_fn
bert.encoder.layer.0.output
bert.encoder.layer.0.output.dense
bert.encoder.layer.0.output.LayerNorm
bert.encoder.layer.0.output.dropout
bert.encoder.layer.1
bert.encoder.layer.1.attention
bert.encoder.layer.1.attention.self
bert.e

In [4]:
BATCH_SIZE = 32

In [5]:
# Load the GLUE SST-2 dataset
# Load SST-2 validation data
dataset = load_dataset("glue", "sst2", split="validation")
# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")

def tokenize(batch):
    return tokenizer(batch['sentence'], padding='max_length', truncation=True, max_length=128)

dataset = dataset.map(tokenize, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)



In [6]:
def collect_activations(model, dataloader, layer):
    activations = []

    def hook_fn(module, input, output):
        activations.append(output.detach())

    # Register a forward hook on the layer to capture activations
    handle = layer.register_forward_hook(hook_fn)

    # Pass the data through the model to collect activations
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['input_ids'].to(next(model.parameters()).device)
            model(inputs)

    handle.remove()  # Clean up the hook
    X = torch.cat(activations, dim=0)
    if X.ndim > 2:
        X = X.permute(2, 0, 1, *range(3, X.ndim))
        X = X.reshape(X.shape[0], -1)
    else:
        X = X.T
    return X


def low_rank_approximation(X, W, rank):
    """
    Perform data-aware low-rank approximation of W using the DRONE algorithm.
    Args:
        X (torch.Tensor): Matrix of input activations (shape: [d1, n]).
        W (torch.Tensor): Original weight matrix (shape: [d2, d1]).
        rank (int): Target rank for low-rank approximation.
    Returns:
        U_star (torch.Tensor): Left matrix of the low-rank approximation.
        V_star (torch.Tensor): Right matrix of the low-rank approximation.
    """

    print("W_shape", W.shape)
    print("X_shape", X.shape)

    if W.ndim < 2:
        print("Invalid weight matrix with ndim < 2, skipping this layer.")
        return W, False
    elif W.shape[1] < rank:
        print("Rank of weight matrix already lower than k value, skipping this layer.")
        return W, False
    elif W.shape[1] != X.shape[0]:
        print("Invalid weight matrix dimensions, skipping this layer.")
        return W, False

    # Step 1: SVD of W -> W = U_W S_W V_W^T
    U_W, S_W, V_W_T = torch.linalg.svd(W, full_matrices=False)
    U_W_r = U_W[:, :rank]
    S_W_r = S_W[:rank]
    V_W = V_W_T.T
    V_W_r = V_W[:, :rank]

    # Step 2: SVD of X -> X = U_X S_X V_X^T

    t = X.shape[0]
    U_X, S_X, V_X_T = torch.linalg.svd(X, full_matrices=False)
    U_X_t = U_X[:, :t]
    S_X_t = S_X[:t]
    V_X = V_X_T.T
    V_X_t = V_X[:, :t]

    # Step 3: Compute Z = S_W_r V_W_r^T U_X_t S_X_t

    z1 = torch.diag(S_W_r) @ V_W_r.T
    z2 = z1 @ U_X_t
    Z = z2 @ torch.diag(S_X_t)
    #Z = torch.diag(S_W_r) @ V_W_r @ U_X_t.T @ torch.diag(S_X_t)

    # Step 4: Truncated SVD of Z to get Z_k = U_Z,k S_Z,k V_Z,k^T
    U_Z, S_Z, V_Z_T = torch.linalg.svd(Z, full_matrices=False)
    U_Z_k = U_Z[:, :rank]
    S_Z_k = S_Z[:rank]
    V_Z_k = V_Z_T[:rank, :]

    # Step 5: Construct U_star and V_star
    U_star = W @ V_W_r @ torch.diag(1 / S_W_r) @ U_Z_k @ torch.diag(S_Z_k)
    V_star = V_Z_k @ torch.diag(1 / S_X_t) @ U_X_t.T

    # Approximate W with the low-rank matrices
    W_approx = U_star @ V_star
    print("W_approx", W_approx.shape)
    W_approx = W_approx.reshape(W.shape)

    return W_approx, True

def compress_layer(model, dataloader, layer_name, rank):
    """
    Compresses a specified layer in the model using data-aware low-rank approximation.
    """
    # Access layer by name
    layer = dict(model.named_modules())[layer_name]

    # Check if layer has weights
    if hasattr(layer, 'weight'):
        W = layer.weight.data

        # Extract activations for input distribution
        X = collect_activations(model, dataloader, layer)

        # Calculate the low-rank approximation of the layer's weight matrix
        W_approx, successful_compressed = low_rank_approximation(X, W, rank)

        # Update layer's weight with the compressed approximation
        if successful_compressed == True:
          layer.weight.data = W_approx

    return model, successful_compressed
# Example usage:
# compressed_model = compress_layer(model, 'fc1', dataloader, rank=10)

In [7]:
def overall_low_rank_approximation(model, dataloader, k_values, allowed_loss_ratio, layer_names):
    """
    Apply Algorithm 2 to compress a model layer-by-layer using Algorithm 1.
    """
    original_loss = evaluate_model_loss(model, dataloader)
    total_layers = len(layer_names)

    for i, layer_name in enumerate(layer_names):
        layer = dict(model.named_modules())[layer_name]
        original_weights = {}  # Store original weights for potential restoration
        print("------------------")
        print("Layer Name: ", layer_name)
        # Iterate over submodules and store original weights
        for name, submodule in layer.named_modules():
            if hasattr(submodule, 'weight'):
                original_weights[name] = submodule.weight.data.clone()  # Store a copy

        current_k = k_values[i]  # Predefined rank for this layer

        model, result = compress_layer(model, dataloader, layer_name, current_k)
        if result == False:
            continue
        # Evaluate new model loss after compression
        new_loss = evaluate_model_loss(model, dataloader)
        print("original_loss", original_loss)
        print("new_loss", new_loss)
        if new_loss / original_loss < 1 + allowed_loss_ratio:
            print(f"Layer {layer_name} compressed with rank {current_k} under allowed loss ratio.")
        else:
            print(f"Layer {layer_name} compression with rank {current_k} exceeded allowed loss. Skipping.")
            # Restore original weights if compression exceeded allowed loss ratio
            for name, submodule in layer.named_modules():
                if hasattr(submodule, 'weight') and name in original_weights:
                    submodule.weight.data = original_weights[name]

    return model

def evaluate_model_loss(model, dataloader):
    """
    Computes the average loss of the model on the SST-2 dataset.
    """
    model.eval()  # Set model to evaluation mode
    loss_fn = torch.nn.CrossEntropyLoss()  # Define the loss function
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():  # No need to compute gradients
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            labels = batch['label'].to(model.device)

            # Get model outputs
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            # Get model outputs
            num_batches += 1

    # Calculate average loss
    avg_loss = total_loss / num_batches
    return avg_loss

def evaluate_model_accuracy(model, dataloader):
    """
    Computes the accuracy of the model on the SST-2 dataset.
    """
    model.eval()  # Set model to evaluation mode
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():  # No need to compute gradients
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            labels = batch['label'].to(model.device)

            # Get model outputs (logits)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Predicted class (0 or 1)
            predictions = torch.argmax(logits, dim=-1)

            # Update counts for accuracy calculation
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)

    # Calculate accuracy
    accuracy = correct_predictions / total_predictions
    return accuracy

In [8]:
acc = evaluate_model_accuracy(model, dataloader)
print(f"Average accuracy of the model: {acc}")

Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

Average accuracy of the model: 0.9243119266055045


In [None]:

layer_names = [name for name, module in model.named_modules() if hasattr(module, 'weight')]
k_values = [600] * 1000 # Chosen rank k for each layer
allowed_loss_ratio = 0.5  # 50% allowed loss increase

# Run the overall low-rank approximation
compressed_model = overall_low_rank_approximation(model, dataloader, k_values, allowed_loss_ratio, layer_names)

acc = evaluate_model_accuracy(compressed_model, dataloader)
print(f"Average accuracy of the compressed model: {acc}")

Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

------------------
Layer Name:  bert.embeddings.word_embeddings
W_shape torch.Size([30522, 768])
X_shape torch.Size([768, 111616])
W_approx torch.Size([30522, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.35083922411182095
Layer bert.embeddings.word_embeddings compressed with rank 600 under allowed loss ratio.
------------------
Layer Name:  bert.embeddings.position_embeddings
W_shape torch.Size([512, 768])
X_shape torch.Size([768, 3584])
W_approx torch.Size([512, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.8645089928592954
Layer bert.embeddings.position_embeddings compression with rank 600 exceeded allowed loss. Skipping.
------------------
Layer Name:  bert.embeddings.token_type_embeddings
W_shape torch.Size([2, 768])
X_shape torch.Size([768, 111616])
W_approx torch.Size([2, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss nan
Layer bert.embeddings.token_type_embeddings compression with rank 600 exceeded allowed loss. Skipping.
------------------
Layer Name:  bert.embeddings.LayerNorm
W_shape torch.Size([768])
X_shape torch.Size([768, 111616])
Invalid weight matrix with ndim < 2, skipping this layer.
------------------
Layer Name:  bert.encoder.layer.0.attention.self.query
W_shape torch.Size([768, 768])
X_shape torch.Size([768, 111616])
W_approx torch.Size([768, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.3511813175199287
Layer bert.encoder.layer.0.attention.self.query compressed with rank 600 under allowed loss ratio.
------------------
Layer Name:  bert.encoder.layer.0.attention.self.key
W_shape torch.Size([768, 768])
X_shape torch.Size([768, 111616])
W_approx torch.Size([768, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.3504318618880851
Layer bert.encoder.layer.0.attention.self.key compressed with rank 600 under allowed loss ratio.
------------------
Layer Name:  bert.encoder.layer.0.attention.self.value
W_shape torch.Size([768, 768])
X_shape torch.Size([768, 111616])
W_approx torch.Size([768, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.3541932583653501
Layer bert.encoder.layer.0.attention.self.value compressed with rank 600 under allowed loss ratio.
------------------
Layer Name:  bert.encoder.layer.0.attention.output.dense
W_shape torch.Size([768, 768])
X_shape torch.Size([768, 111616])
W_approx torch.Size([768, 768])


Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.35399203148803543
Layer bert.encoder.layer.0.attention.output.dense compressed with rank 600 under allowed loss ratio.
------------------
Layer Name:  bert.encoder.layer.0.attention.output.LayerNorm
W_shape torch.Size([768])
X_shape torch.Size([768, 111616])
Invalid weight matrix with ndim < 2, skipping this layer.
------------------
Layer Name:  bert.encoder.layer.0.intermediate.dense
W_shape torch.Size([3072, 768])
X_shape torch.Size([3072, 111616])
Invalid weight matrix dimensions, skipping this layer.
------------------
Layer Name:  bert.encoder.layer.0.output.dense
W_shape torch.Size([768, 3072])
X_shape torch.Size([768, 111616])
Invalid weight matrix dimensions, skipping this layer.
------------------
Layer Name:  bert.encoder.layer.0.output.LayerNorm
W_shape torch.Size([768])
X_shape torch.Size([768, 111616])
Invalid weight matrix with ndim < 2, skipping this layer.
------------------
Layer Name:  bert.encoder.layer.1.attention.self.q

Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]

original_loss 0.28877695210810217
new_loss 0.35354472710085766
Layer bert.encoder.layer.1.attention.self.query compressed with rank 600 under allowed loss ratio.
------------------
Layer Name:  bert.encoder.layer.1.attention.self.key
