In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import os

# Force to use only GPU 0
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

class HiRALayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 32,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
    ):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        
        self.lora_A = nn.Parameter(torch.zeros(in_features, r))
        self.lora_B = nn.Parameter(torch.randn(r, out_features))
        
        nn.init.zeros_(self.lora_A)
        nn.init.kaiming_uniform_(self.lora_B, a=0)
        
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
        
    def forward(self, x: torch.Tensor, W0: torch.Tensor) -> torch.Tensor:
        result = F.linear(x, W0)
        
        lora_update = self.lora_A @ self.lora_B  # [in_features, out_features]
        
        hadamard_update = W0.T * lora_update  # [in_features, out_features]
        
        result += self.lora_dropout(x @ hadamard_update) * self.scaling
        
        return result


class HiRALinear(nn.Module):
    def __init__(self, linear_layer: nn.Linear, r: int = 32, lora_alpha: int = 32):
        super().__init__()
        self.linear = linear_layer
        self.hira = HiRALayer(
            in_features=linear_layer.in_features,
            out_features=linear_layer.out_features,
            r=r,
            lora_alpha=lora_alpha
        )
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.hira(x, self.linear.weight)

In [2]:
import random
import numpy as np

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)


In [3]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
import torch.nn as nn

def apply_hira_to_model(model, r=32, lora_alpha=32, target_modules=['q_lin', 'k_lin', 'v_lin', 'out_lin', 'ffn.lin1', 'ffn.lin2']):

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if any(target in name for target in target_modules):
                parent_name = '.'.join(name.split('.')[:-1])
                attr_name = name.split('.')[-1]
                
                if parent_name:
                    parent_module = model.get_submodule(parent_name)
                else:
                    parent_module = model
                
                hira_layer = HiRALinear(module, r=r, lora_alpha=lora_alpha)
                setattr(parent_module, attr_name, hira_layer)
                
                print(f"Applied HiRA to: {name}")
    
    return model

def get_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, all_params, 100 * trainable_params / all_params


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

def load_sst2_data(tokenizer, max_length=128):
    """8:1:1 """
    raw = load_dataset("glue", "sst2")  # 有 train / validation / test[web:110]

    train_valid = raw["train"]
    train_valid = train_valid.shuffle(seed=42)
    n = len(train_valid)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val

    train_dataset = train_valid.select(range(n_train))
    val_dataset   = train_valid.select(range(n_train, n_train + n_val))
    test_dataset  = train_valid.select(range(n_train + n_val, n))

    def preprocess_function(examples):
        enc = tokenizer(
            examples["sentence"],
            truncation=True,
            max_length=max_length
        )
        enc["labels"] = examples["label"]
        return enc

    train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
    val_dataset   = val_dataset.map(preprocess_function,   batched=True, remove_columns=val_dataset.column_names)
    test_dataset  = test_dataset.map(preprocess_function,  batched=True, remove_columns=test_dataset.column_names)

    return {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }



def load_imdb_data(tokenizer, max_length=256):

    raw = load_dataset("imdb")  

    train_full = raw["train"]               
    train_full = train_full.shuffle(seed=42)
    n = len(train_full)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val

    train_dataset = train_full.select(range(n_train))
    val_dataset   = train_full.select(range(n_train, n_train + n_val))
    test_dataset  = train_full.select(range(n_train + n_val, n))

    def preprocess_function(examples):
        enc = tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
        )
        enc["labels"] = examples["label"]
        return enc

    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    val_dataset = val_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=val_dataset.column_names,
    )
    test_dataset = test_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=test_dataset.column_names,
    )

    return {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }



# After changing

In [5]:
import os
import time
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from tqdm import tqdm
import evaluate
from datasets import load_dataset
import pandas as pd




def get_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for p in model.parameters():
        num = p.numel()
        all_params += num
        if p.requires_grad:
            trainable_params += num
    percentage = 100.0 * trainable_params / all_params if all_params > 0 else 0.0
    return trainable_params, all_params, percentage


def get_model_sparsity(model, threshold: float = 1e-3) -> float:
    total_elems = 0
    small_elems = 0
    for p in model.parameters():
        if p is None:
            continue
        data = p.detach()
        total_elems += data.numel()
        small_elems += (data.abs() < threshold).sum().item()
    if total_elems == 0:
        return 0.0
    return small_elems / total_elems

def train_hira_model(
    dataset_name: str = "sst2",
    model_name: str = "distilbert-base-uncased",
    r: int = 32,
    lora_alpha: int = 32,
    num_epochs: int = 30,
    batch_size: int = 16,
    learning_rate: float = 5e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 100,
    logging_steps: int = 100,
    max_length: int = 128,
    early_stop_patience: int = 3,
    output_dir: str = "./results_hira",
    resume_from_checkpoint: bool = False,
):

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"[{dataset_name}][r={r}] Using device: {device}")

    # Check for existing checkpoint in /hpc/group/xielab/hl385/LoRA
    checkpoint_path = os.path.join("/hpc/group/xielab/hl385/LoRA", f"checkpoint_{dataset_name}_r{r}.pt")
    start_epoch = 0
    if resume_from_checkpoint and os.path.exists(checkpoint_path):
        print(f"[{dataset_name}][r={r}] Found checkpoint: {checkpoint_path}")
        print(f"[{dataset_name}][r={r}] Resuming training from checkpoint...\n")
    else:
        checkpoint_path = None

    # Reset peak GPU memory stats
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        return_tensors="pt",
    )

    num_labels = 2
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    model = apply_hira_to_model(model, r=r, lora_alpha=lora_alpha)
    model = model.to(device)

    trainable_params, all_params, percentage = get_trainable_parameters(model)
    print(
        f"[{dataset_name}][r={r}] Trainable params: {trainable_params:,} || "
        f"All params: {all_params:,} || Trainable%: {percentage:.4f}%"
    )

    if dataset_name == "sst2":
        dataset = load_sst2_data(tokenizer, max_length)
    elif dataset_name == "imdb":
        dataset = load_imdb_data(tokenizer, max_length)
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    train_dataloader = DataLoader(
        dataset["train"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
    )
    val_split_name = "validation" if "validation" in dataset else "test"
    val_dataloader = DataLoader(
        dataset[val_split_name],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    test_dataloader = DataLoader(
        dataset["test"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    os.makedirs(output_dir, exist_ok=True)

    # Load checkpoint if resuming
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"]
        best_val_f1 = checkpoint["best_val_f1"]
        best_val_accuracy = checkpoint["best_val_accuracy"]
        best_epoch = checkpoint["best_epoch"]
        print(f"[{dataset_name}][r={r}] Resumed from epoch {start_epoch}")
        print(f"[{dataset_name}][r={r}] Best F1 so far: {best_val_f1:.4f} at epoch {best_epoch}\n")
    else:
        best_val_accuracy = 0.0
        best_val_f1 = 0.0
        best_epoch = 0

    epoch_times = []
    total_train_time = 0.0
    epochs_without_improvement = 0
    early_stopped = False

    for epoch in range(start_epoch, num_epochs):
        start_t = time.perf_counter()

        model.train()
        total_loss = 0.0
        progress_bar = tqdm(
            train_dataloader,
            desc=f"[{dataset_name}][r={r}] Epoch {epoch + 1}/{num_epochs}",
            ncols=100,
            leave=False,
        )
        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            avg_loss = total_loss / (step + 1)
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "avg_loss": f"{avg_loss:.4f}"})

        avg_train_loss = total_loss / len(train_dataloader)

        # Validation
        acc_metric = evaluate.load("accuracy")
        f1_metric = evaluate.load("f1")
        model.eval()
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=-1)
            acc_metric.add_batch(predictions=preds, references=batch["labels"])
            f1_metric.add_batch(predictions=preds, references=batch["labels"])
        val_acc = acc_metric.compute()["accuracy"]
        val_f1 = f1_metric.compute()["f1"]

        end_t = time.perf_counter()
        epoch_time = end_t - start_t
        epoch_times.append(epoch_time)
        total_train_time += epoch_time

        print(
            f"[{dataset_name}][r={r}] Epoch {epoch + 1} | "
            f"train_loss={avg_train_loss:.4f} | "
            f"val_acc={val_acc:.4f} | "
            f"val_f1={val_f1:.4f} | "
            f"time={epoch_time:.2f}s"
        )

        # Update best model and reset counter if improved
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_val_accuracy = val_acc
            best_epoch = epoch + 1
            epochs_without_improvement = 0

            save_path = os.path.join("/hpc/group/xielab/hl385/LoRA", f"best_model_hira_{dataset_name}_r{r}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_f1": best_val_f1,
                    "val_accuracy": best_val_accuracy,
                },
                save_path,
            )
        else:
            # F1 did not exceed best_val_f1
            epochs_without_improvement += 1
        
        # Early stopping: stop when no improvement for patience epochs
        if epochs_without_improvement >= early_stop_patience:
            early_stopped = True
            print(
                f"\n[Early Stopping] No improvement over best F1 for {early_stop_patience} consecutive epochs. "
                f"Stopped at epoch {epoch + 1}. Best F1: {best_val_f1:.4f} at epoch {best_epoch}\n"
            )
            break
        
        # Save checkpoint to /hpc/group/xielab/hl385/LoRA
        checkpoint_save_path = os.path.join("/hpc/group/xielab/hl385/LoRA", f"checkpoint_hira_{dataset_name}_r{r}.pt")
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "best_val_f1": best_val_f1,
                "best_val_accuracy": best_val_accuracy,
                "best_epoch": best_epoch,
            },
            checkpoint_save_path,
        )
    
    avg_time_per_epoch = sum(epoch_times) / len(epoch_times) if epoch_times else 0.0
    
    # Get peak GPU memory
    peak_gpu_memory_mb = 0.0
    if torch.cuda.is_available():
        peak_gpu_memory_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    
    print(
        f"[{dataset_name}][r={r}] Training completed! "
        f"best_val_acc={best_val_accuracy:.4f}, "
        f"best_val_f1={best_val_f1:.4f}, "
        f"avg_time/epoch={avg_time_per_epoch:.2f}s, "
        f"best_epoch={best_epoch}, "
        f"early_stopped={early_stopped}, "
        f"peak_gpu_memory={peak_gpu_memory_mb:.2f}MB, "
        f"trainable_params={trainable_params}, "
        f"trainable_ratio={percentage:.4f}%"
    )

    # Load best model for testing
    best_model_path = os.path.join(output_dir, f"best_model_r{r}.pt")
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"[{dataset_name}][r={r}] Loaded best model from {best_model_path}")
    else:
        print(f"[{dataset_name}][r={r}] Warning: Best model not found, using current model")

    # Test evaluation
    acc_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")
    model.eval()
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=-1)
        acc_metric.add_batch(predictions=preds, references=batch["labels"])
        f1_metric.add_batch(predictions=preds, references=batch["labels"])

    test_acc = acc_metric.compute()["accuracy"]
    test_f1 = f1_metric.compute()["f1"]
    print(f"[{dataset_name}][r={r}] Final test accuracy: {test_acc:.4f}, test_f1={test_f1:.4f}")

    sparsity = get_model_sparsity(model, threshold=1e-3)

    return {
        "dataset": dataset_name,
        "rank": r,
        "final_test_accuracy": test_acc,
        "final_test_f1": test_f1,
        "final_val_accuracy": best_val_accuracy,
        "final_val_f1": best_val_f1,
        "total_parameters": all_params,
        "trainable_parameters": trainable_params,
        "trainable_percentage": percentage,
        "total_training_time": total_train_time,
        "average_epoch_time": avg_time_per_epoch,
        "best_f1_epoch": best_epoch,
        "best_val_f1": best_val_f1,
        "peak_gpu_memory_mb": peak_gpu_memory_mb,
        "sparsity": sparsity,
        "early_stopped": early_stopped,
    }


def generate_summary_table(results, save_prefix="HiRA"):
    """Generate and save summary table with all metrics"""
    
    summary_data = []
    for result in results:
        summary_data.append({
            "Rank": result["rank"],
            "Test Acc": f"{result['final_test_accuracy']:.4f}",
            "Test F1": f"{result['final_test_f1']:.4f}",
            "Val Acc": f"{result['final_val_accuracy']:.4f}",
            "Val F1": f"{result['final_val_f1']:.4f}",
            "Best Val F1": f"{result['best_val_f1']:.4f}",
            "Best F1 Epoch": result['best_f1_epoch'],
            "Trainable Params": f"{result['trainable_parameters']:,}",
            "Trainable %": f"{result['trainable_percentage']:.2f}%",
            "Total Time (s)": f"{result['total_training_time']:.2f}",
            "Avg Epoch (s)": f"{result['average_epoch_time']:.2f}",
            "Early Stopped": "Yes" if result['early_stopped'] else "No",
            "Peak GPU (MB)": f"{result['peak_gpu_memory_mb']:.2f}",
            "Sparsity (<1e-3)": f"{result['sparsity']*100:.2f}%",
        })
    
    summary_df = pd.DataFrame(summary_data)
    
    # Print summary table
    print("\n" + "="*80)
    print("BENCHMARK SUMMARY")
    print("="*80)
    print(summary_df.to_string(index=False))
    
    # Save summary to CSV
    csv_filename = f"{save_prefix}_benchmark_summary.csv"
    summary_df.to_csv(csv_filename, index=False)
    print(f"\n✓ Summary saved to '{csv_filename}'")
    print("="*80)
    
    return summary_df

In [6]:
print("=" * 50)
print("Training HiRA on SST-2 with different ranks")
print("=" * 50)

results = []
for r in [2, 4, 8, 16]:
    print("\n" + "-" * 30)
    print(f"HiRA with rank r={r}")
    print("-" * 30)
    res = train_hira_model(
        dataset_name="sst2",
        model_name="distilbert-base-uncased",
        r=r,
        lora_alpha=32,
        num_epochs=30,        
        batch_size=32,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=100,
        logging_steps=100,
        max_length=128,
        early_stop_patience=3,
        output_dir="./results_hira",
        resume_from_checkpoint=False,
    )
    results.append(res)

# Generate and save summary table
generate_summary_table(results, save_prefix="HiRA_SST2")

print("\nSummary over ranks:")
for res in results:
    print(
        f"Rank={res['rank']}, "
        f"Test Acc={res['final_test_accuracy']:.4f}, "
        f"Test F1={res['final_test_f1']:.4f}, "
        f"Val Acc={res['final_val_accuracy']:.4f}, "
        f"Val F1={res['final_val_f1']:.4f}, "
        f"Best Val F1={res['best_val_f1']:.4f}, "
        f"Best F1 Epoch={res['best_f1_epoch']}, "
        f"Trainable Params={res['trainable_parameters']:,}, "
        f"Trainable %={res['trainable_percentage']:.2f}%, "
        f"Total Time (s)={res['total_training_time']:.2f}, "
        f"Avg Epoch (s)={res['average_epoch_time']:.2f}, "
        f"Early Stopped={'Yes' if res['early_stopped'] else 'No'}, "
        f"Peak GPU (MB)={res['peak_gpu_memory_mb']:.2f}, "
        f"Sparsity (<1e-3)={res['sparsity']*100:.2f}%"
    )

Training HiRA on SST-2 with different ranks

------------------------------
HiRA with rank r=2
------------------------------
[sst2][r=2] Using device: cuda


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

                                                                                                    

[sst2][r=2] Epoch 1 | train_loss=0.2431 | val_acc=0.9307 | val_f1=0.9369 | time=73.12s


                                                                                                    

[sst2][r=2] Epoch 2 | train_loss=0.0996 | val_acc=0.9361 | val_f1=0.9430 | time=70.74s


                                                                                                    

[sst2][r=2] Epoch 3 | train_loss=0.0574 | val_acc=0.9379 | val_f1=0.9444 | time=74.40s


                                                                                                    

[sst2][r=2] Epoch 4 | train_loss=0.0367 | val_acc=0.9373 | val_f1=0.9442 | time=73.94s


                                                                                                    

[sst2][r=2] Epoch 5 | train_loss=0.0285 | val_acc=0.9329 | val_f1=0.9398 | time=73.87s


                                                                                                    

[sst2][r=2] Epoch 6 | train_loss=0.0205 | val_acc=0.9330 | val_f1=0.9401 | time=71.63s

[Early Stopping] No improvement over best F1 for 3 consecutive epochs. Stopped at epoch 6. Best F1: 0.9444 at epoch 3

[sst2][r=2] Training completed! best_val_acc=0.9379, best_val_f1=0.9444, avg_time/epoch=72.95s, best_epoch=3, early_stopped=True, peak_gpu_memory=1384.98MB, trainable_params=24612098, trainable_ratio=36.6683%
[sst2][r=2] Final test accuracy: 0.9363, test_f1=0.9439

------------------------------
HiRA with rank r=4
------------------------------
[sst2][r=4] Using device: cuda


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

                                                                                                    

[sst2][r=4] Epoch 1 | train_loss=0.2436 | val_acc=0.9369 | val_f1=0.9443 | time=73.82s


                                                                                                    

[sst2][r=4] Epoch 2 | train_loss=0.1023 | val_acc=0.9376 | val_f1=0.9445 | time=72.03s


                                                                                                    

[sst2][r=4] Epoch 3 | train_loss=0.0583 | val_acc=0.9373 | val_f1=0.9442 | time=73.71s


                                                                                                    

[sst2][r=4] Epoch 4 | train_loss=0.0390 | val_acc=0.9372 | val_f1=0.9439 | time=73.12s


                                                                                                    

[sst2][r=4] Epoch 5 | train_loss=0.0274 | val_acc=0.9333 | val_f1=0.9409 | time=73.60s

[Early Stopping] No improvement over best F1 for 3 consecutive epochs. Stopped at epoch 5. Best F1: 0.9445 at epoch 2

[sst2][r=4] Training completed! best_val_acc=0.9376, best_val_f1=0.9445, avg_time/epoch=73.25s, best_epoch=2, early_stopped=True, peak_gpu_memory=1835.24MB, trainable_params=24777986, trainable_ratio=36.8244%
[sst2][r=4] Final test accuracy: 0.9350, test_f1=0.9430

------------------------------
HiRA with rank r=8
------------------------------
[sst2][r=8] Using device: cuda


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

                                                                                                    

[sst2][r=8] Epoch 1 | train_loss=0.2416 | val_acc=0.9342 | val_f1=0.9424 | time=74.42s


                                                                                                    

[sst2][r=8] Epoch 2 | train_loss=0.1017 | val_acc=0.9378 | val_f1=0.9444 | time=73.71s


                                                                                                    

[sst2][r=8] Epoch 3 | train_loss=0.0560 | val_acc=0.9402 | val_f1=0.9470 | time=72.26s


                                                                                                    

[sst2][r=8] Epoch 4 | train_loss=0.0383 | val_acc=0.9378 | val_f1=0.9445 | time=74.21s


                                                                                                    

[sst2][r=8] Epoch 5 | train_loss=0.0265 | val_acc=0.9339 | val_f1=0.9418 | time=73.86s


                                                                                                    

[sst2][r=8] Epoch 6 | train_loss=0.0207 | val_acc=0.9347 | val_f1=0.9414 | time=73.27s

[Early Stopping] No improvement over best F1 for 3 consecutive epochs. Stopped at epoch 6. Best F1: 0.9470 at epoch 3

[sst2][r=8] Training completed! best_val_acc=0.9402, best_val_f1=0.9470, avg_time/epoch=73.62s, best_epoch=3, early_stopped=True, peak_gpu_memory=1838.29MB, trainable_params=25109762, trainable_ratio=37.1344%
[sst2][r=8] Final test accuracy: 0.9353, test_f1=0.9427

------------------------------
HiRA with rank r=16
------------------------------
[sst2][r=16] Using device: cuda


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

                                                                                                    

[sst2][r=16] Epoch 1 | train_loss=0.2506 | val_acc=0.9342 | val_f1=0.9422 | time=73.64s


                                                                                                    

[sst2][r=16] Epoch 2 | train_loss=0.1051 | val_acc=0.9376 | val_f1=0.9439 | time=73.97s


                                                                                                    

[sst2][r=16] Epoch 3 | train_loss=0.0593 | val_acc=0.9366 | val_f1=0.9433 | time=74.54s


                                                                                                    

[sst2][r=16] Epoch 4 | train_loss=0.0369 | val_acc=0.9345 | val_f1=0.9413 | time=73.63s


                                                                                                    

[sst2][r=16] Epoch 5 | train_loss=0.0256 | val_acc=0.9354 | val_f1=0.9424 | time=73.55s

[Early Stopping] No improvement over best F1 for 3 consecutive epochs. Stopped at epoch 5. Best F1: 0.9439 at epoch 2

[sst2][r=16] Training completed! best_val_acc=0.9376, best_val_f1=0.9439, avg_time/epoch=73.86s, best_epoch=2, early_stopped=True, peak_gpu_memory=2296.67MB, trainable_params=25773314, trainable_ratio=37.7453%
[sst2][r=16] Final test accuracy: 0.9371, test_f1=0.9448

BENCHMARK SUMMARY
 Rank Test Acc Test F1 Val Acc Val F1 Best Val F1  Best F1 Epoch Trainable Params Trainable % Total Time (s) Avg Epoch (s) Early Stopped Peak GPU (MB) Sparsity (<1e-3)
    2   0.9363  0.9439  0.9379 0.9444      0.9444              3       24,612,098      36.67%         437.71         72.95           Yes       1384.98            1.71%
    4   0.9350  0.9430  0.9376 0.9445      0.9445              2       24,777,986      36.82%         366.27         73.25           Yes       1835.24            1.72%
  

In [7]:
# Re-generate summary table if needed
generate_summary_table(results, save_prefix="HiRA_SST2")


BENCHMARK SUMMARY
 Rank Test Acc Test F1 Val Acc Val F1 Best Val F1  Best F1 Epoch Trainable Params Trainable % Total Time (s) Avg Epoch (s) Early Stopped Peak GPU (MB) Sparsity (<1e-3)
    2   0.9363  0.9439  0.9379 0.9444      0.9444              3       24,612,098      36.67%         437.71         72.95           Yes       1384.98            1.71%
    4   0.9350  0.9430  0.9376 0.9445      0.9445              2       24,777,986      36.82%         366.27         73.25           Yes       1835.24            1.72%
    8   0.9353  0.9427  0.9402 0.9470      0.9470              3       25,109,762      37.13%         441.73         73.62           Yes       1838.29            1.71%
   16   0.9371  0.9448  0.9376 0.9439      0.9439              2       25,773,314      37.75%         369.32         73.86           Yes       2296.67            1.72%

✓ Summary saved to 'HiRA_SST2_benchmark_summary.csv'


Unnamed: 0,Rank,Test Acc,Test F1,Val Acc,Val F1,Best Val F1,Best F1 Epoch,Trainable Params,Trainable %,Total Time (s),Avg Epoch (s),Early Stopped,Peak GPU (MB),Sparsity (<1e-3)
0,2,0.9363,0.9439,0.9379,0.9444,0.9444,3,24612098,36.67%,437.71,72.95,Yes,1384.98,1.71%
1,4,0.935,0.943,0.9376,0.9445,0.9445,2,24777986,36.82%,366.27,73.25,Yes,1835.24,1.72%
2,8,0.9353,0.9427,0.9402,0.947,0.947,3,25109762,37.13%,441.73,73.62,Yes,1838.29,1.71%
3,16,0.9371,0.9448,0.9376,0.9439,0.9439,2,25773314,37.75%,369.32,73.86,Yes,2296.67,1.72%
