In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import os

class HiRALayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 32,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
    ):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        
        self.lora_A = nn.Parameter(torch.zeros(in_features, r))
        self.lora_B = nn.Parameter(torch.randn(r, out_features))
        
        nn.init.zeros_(self.lora_A)
        nn.init.kaiming_uniform_(self.lora_B, a=0)
        
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
        
    def forward(self, x: torch.Tensor, W0: torch.Tensor) -> torch.Tensor:
        result = F.linear(x, W0)
        
        lora_update = self.lora_A @ self.lora_B  # [in_features, out_features]
        
        hadamard_update = W0.T * lora_update  # [in_features, out_features]
        
        result += self.lora_dropout(x @ hadamard_update) * self.scaling
        
        return result


class HiRALinear(nn.Module):
    def __init__(self, linear_layer: nn.Linear, r: int = 32, lora_alpha: int = 32):
        super().__init__()
        self.linear = linear_layer
        self.hira = HiRALayer(
            in_features=linear_layer.in_features,
            out_features=linear_layer.out_features,
            r=r,
            lora_alpha=lora_alpha
        )
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.hira(x, self.linear.weight)




In [4]:
import random
import numpy as np

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)


In [5]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
import torch.nn as nn

def apply_hira_to_model(model, r=32, lora_alpha=32, target_modules=['q_lin', 'k_lin', 'v_lin', 'out_lin', 'ffn.lin1', 'ffn.lin2']):

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if any(target in name for target in target_modules):
                parent_name = '.'.join(name.split('.')[:-1])
                attr_name = name.split('.')[-1]
                
                if parent_name:
                    parent_module = model.get_submodule(parent_name)
                else:
                    parent_module = model
                
                hira_layer = HiRALinear(module, r=r, lora_alpha=lora_alpha)
                setattr(parent_module, attr_name, hira_layer)
                
                print(f"Applied HiRA to: {name}")
    
    return model

def get_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, all_params, 100 * trainable_params / all_params


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

def load_sst2_data(tokenizer, max_length=128):
    """8:1:1 """
    raw = load_dataset("glue", "sst2")  # 有 train / validation / test[web:110]

    train_valid = raw["train"]
    train_valid = train_valid.shuffle(seed=42)
    n = len(train_valid)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val

    train_dataset = train_valid.select(range(n_train))
    val_dataset   = train_valid.select(range(n_train, n_train + n_val))
    test_dataset  = train_valid.select(range(n_train + n_val, n))

    def preprocess_function(examples):
        enc = tokenizer(
            examples["sentence"],
            truncation=True,
            max_length=max_length
        )
        enc["labels"] = examples["label"]
        return enc

    train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
    val_dataset   = val_dataset.map(preprocess_function,   batched=True, remove_columns=val_dataset.column_names)
    test_dataset  = test_dataset.map(preprocess_function,  batched=True, remove_columns=test_dataset.column_names)

    return {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }



def load_imdb_data(tokenizer, max_length=256):

    raw = load_dataset("imdb")  

    train_full = raw["train"]               
    train_full = train_full.shuffle(seed=42)
    n = len(train_full)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val

    train_dataset = train_full.select(range(n_train))
    val_dataset   = train_full.select(range(n_train, n_train + n_val))
    test_dataset  = train_full.select(range(n_train + n_val, n))

    def preprocess_function(examples):
        enc = tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
        )
        enc["labels"] = examples["label"]
        return enc

    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    val_dataset = val_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=val_dataset.column_names,
    )
    test_dataset = test_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=test_dataset.column_names,
    )

    return {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }



# After changing

In [None]:
import os
import time
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from tqdm import tqdm
import evaluate
from datasets import load_dataset
import pandas as pd

# Checkpoint directory configuration
CHECKPOINT_DIR = "/hpc/group/xielab/hl385/LoRA"




def get_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for p in model.parameters():
        num = p.numel()
        all_params += num
        if p.requires_grad:
            trainable_params += num
    percentage = 100.0 * trainable_params / all_params if all_params > 0 else 0.0
    return trainable_params, all_params, percentage


def get_model_sparsity(model, threshold: float = 1e-3) -> float:
    total_elems = 0
    small_elems = 0
    for p in model.parameters():
        if p is None:
            continue
        data = p.detach()
        total_elems += data.numel()
        small_elems += (data.abs() < threshold).sum().item()
    if total_elems == 0:
        return 0.0
    return small_elems / total_elems

def train_hira_model(
    dataset_name: str = "sst2",
    model_name: str = "distilbert-base-uncased",
    r: int = 32,
    lora_alpha: int = 32,
    num_epochs: int = 30,
    batch_size: int = 16,
    learning_rate: float = 5e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 100,
    logging_steps: int = 100,
    max_length: int = 128,
    early_stop_patience: int = 3,
    output_dir: str = "./results_hira",
    resume_from_checkpoint: bool = False,
):

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"[{dataset_name}][r={r}] Using device: {device}")

    # Check for existing checkpoint
    checkpoint_path = os.path.join(output_dir, f"checkpoint_r{r}.pt")
    start_epoch = 0
    if resume_from_checkpoint and os.path.exists(checkpoint_path):
        print(f"[{dataset_name}][r={r}] Found checkpoint: {checkpoint_path}")
        print(f"[{dataset_name}][r={r}] Resuming training from checkpoint...\n")
    else:
        checkpoint_path = None

    # Reset peak GPU memory stats
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        return_tensors="pt",
    )

    num_labels = 2
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    model = apply_hira_to_model(model, r=r, lora_alpha=lora_alpha)
    model = model.to(device)

    trainable_params, all_params, percentage = get_trainable_parameters(model)
    print(
        f"[{dataset_name}][r={r}] Trainable params: {trainable_params:,} || "
        f"All params: {all_params:,} || Trainable%: {percentage:.4f}%"
    )

    if dataset_name == "sst2":
        dataset = load_sst2_data(tokenizer, max_length)
    elif dataset_name == "imdb":
        dataset = load_imdb_data(tokenizer, max_length)
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    train_dataloader = DataLoader(
        dataset["train"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
    )
    val_split_name = "validation" if "validation" in dataset else "test"
    val_dataloader = DataLoader(
        dataset[val_split_name],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    test_dataloader = DataLoader(
        dataset["test"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    os.makedirs(output_dir, exist_ok=True)

    # Load checkpoint if resuming
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"]
        best_val_f1 = checkpoint["best_val_f1"]
        best_val_accuracy = checkpoint["best_val_accuracy"]
        best_epoch = checkpoint["best_epoch"]
        print(f"[{dataset_name}][r={r}] Resumed from epoch {start_epoch}")
        print(f"[{dataset_name}][r={r}] Best F1 so far: {best_val_f1:.4f} at epoch {best_epoch}\n")
    else:
        best_val_accuracy = 0.0
        best_val_f1 = 0.0
        best_epoch = 0

    epoch_times = []
    total_train_time = 0.0
    epochs_without_improvement = 0
    early_stopped = False

    for epoch in range(start_epoch, num_epochs):
        start_t = time.perf_counter()

        model.train()
        total_loss = 0.0
        progress_bar = tqdm(
            train_dataloader,
            desc=f"[{dataset_name}][r={r}] Epoch {epoch + 1}/{num_epochs}",
        )
        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})
            
            # Log loss every logging_steps
            if (step + 1) % logging_steps == 0:
                avg_loss = total_loss / (step + 1)
                print(f"  Step {step + 1}/{len(train_dataloader)} | avg_loss={avg_loss:.4f}")

        avg_train_loss = total_loss / len(train_dataloader)

        # Validation
        acc_metric = evaluate.load("accuracy")
        f1_metric = evaluate.load("f1")
        model.eval()
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=-1)
            acc_metric.add_batch(predictions=preds, references=batch["labels"])
            f1_metric.add_batch(predictions=preds, references=batch["labels"])
        val_acc = acc_metric.compute()["accuracy"]
        val_f1 = f1_metric.compute()["f1"]

        end_t = time.perf_counter()
        epoch_time = end_t - start_t
        epoch_times.append(epoch_time)
        total_train_time += epoch_time

        print(
            f"[{dataset_name}][r={r}] Epoch {epoch + 1} | "
            f"train_loss={avg_train_loss:.4f} | "
            f"val_acc={val_acc:.4f} | "
            f"val_f1={val_f1:.4f} | "
            f"time={epoch_time:.2f}s"
        )
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_val_accuracy = val_acc
            best_epoch = epoch + 1

            epochs_without_improvement = 0

            save_path = os.path.join(CHECKPOINT_DIR, f"best_model_hira_{dataset_name}_r{r}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_f1": best_val_f1,
                    "val_accuracy": best_val_accuracy,
                },
                save_path,
            )
            print(
                f"[{dataset_name}][r={r}] Saved best model with "
                f"val_f1={best_val_f1:.4f} at epoch {best_epoch}"
            )
        else:
            # F1 did not exceed best_val_f1
            epochs_without_improvement += 1
        
        # Early stopping: stop when no improvement for patience epochs
        if epochs_without_improvement >= early_stop_patience:
            early_stopped = True
            print(
                f"\n[Early Stopping] No improvement over best F1 for {early_stop_patience} consecutive epochs. Stopping training."
            )
            print(f"[Best Model] Best F1: {best_val_f1:.4f} at epoch {best_epoch}\n")
            break
        # Save checkpoint to CHECKPOINT_DIR
        checkpoint_save_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_hira_{dataset_name}_r{r}.pt")
        checkpoint_save_path = os.path.join(output_dir, f"checkpoint_r{r}.pt")
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "best_val_f1": best_val_f1,
                "best_val_accuracy": best_val_accuracy,
                "best_epoch": best_epoch,
            },
            checkpoint_save_path,
        )
    
    avg_time_per_epoch = sum(epoch_times) / len(epoch_times) if epoch_times else 0.0
    
    # Get peak GPU memory
    peak_gpu_memory_mb = 0.0
    if torch.cuda.is_available():
        peak_gpu_memory_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    
    print(
        f"[{dataset_name}][r={r}] Training completed! "
        f"best_val_acc={best_val_accuracy:.4f}, "
        f"best_val_f1={best_val_f1:.4f}, "
        f"avg_time/epoch={avg_time_per_epoch:.2f}s, "
        f"best_epoch={best_epoch}, "
        f"early_stopped={early_stopped}, "
        f"peak_gpu_memory={peak_gpu_memory_mb:.2f}MB, "
        f"trainable_params={trainable_params}, "
        f"trainable_ratio={percentage:.4f}%"
    )

    best_model_path = os.path.join(CHECKPOINT_DIR, f"best_model_hira_{dataset_name}_r{r}.pt")
    best_model_path = os.path.join(output_dir, f"best_model_r{r}.pt")
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"[{dataset_name}][r={r}] Loaded best model from {best_model_path}")
    else:
        print(f"[{dataset_name}][r={r}] Warning: Best model not found, using current model")

    # Test evaluation
    acc_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")
    model.eval()
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=-1)
        acc_metric.add_batch(predictions=preds, references=batch["labels"])
        f1_metric.add_batch(predictions=preds, references=batch["labels"])

    test_acc = acc_metric.compute()["accuracy"]
    test_f1 = f1_metric.compute()["f1"]
    print(f"[{dataset_name}][r={r}] Final test accuracy: {test_acc:.4f}, test_f1={test_f1:.4f}")

    sparsity = get_model_sparsity(model, threshold=1e-3)

    return {
        "dataset": dataset_name,
        "rank": r,
        "final_test_accuracy": test_acc,
        "final_test_f1": test_f1,
        "final_val_accuracy": best_val_accuracy,
        "final_val_f1": best_val_f1,
        "total_parameters": all_params,
        "trainable_parameters": trainable_params,
        "trainable_percentage": percentage,
        "total_training_time": total_train_time,
        "average_epoch_time": avg_time_per_epoch,
        "best_f1_epoch": best_epoch,
        "best_val_f1": best_val_f1,
        "peak_gpu_memory_mb": peak_gpu_memory_mb,
        "sparsity": sparsity,
        "early_stopped": early_stopped,
    }
        "early_stopped": early_stopped,

def generate_summary_table(results, save_prefix="HiRA"):
    """Generate and save summary table with all metrics"""
    
    summary_data = []
    for result in results:
        summary_data.append({
            "Rank": result["rank"],
            "Test Acc": f"{result['final_test_accuracy']:.4f}",
            "Test F1": f"{result['final_test_f1']:.4f}",
            "Val Acc": f"{result['final_val_accuracy']:.4f}",
            "Val F1": f"{result['final_val_f1']:.4f}",
            "Best Val F1": f"{result['best_val_f1']:.4f}",
            "Best F1 Epoch": result['best_f1_epoch'],
            "Trainable Params": f"{result['trainable_parameters']:,}",
            "Trainable %": f"{result['trainable_percentage']:.2f}%",
            "Total Time (s)": f"{result['total_training_time']:.2f}",
            "Avg Epoch (s)": f"{result['average_epoch_time']:.2f}",
            "Early Stopped": "Yes" if result['early_stopped'] else "No",
            "Peak GPU (MB)": f"{result['peak_gpu_memory_mb']:.2f}",
            "Sparsity (<1e-3)": f"{result['sparsity']*100:.2f}%",
        })
    
    summary_df = pd.DataFrame(summary_data)
    
    # Print summary table
    print("\n" + "="*80)
    print("BENCHMARK SUMMARY")
    print("="*80)
    print(summary_df.to_string(index=False))
    summary_df = pd.DataFrame(summary_data)
    # Save summary to CSV
    csv_filename = f"{save_prefix}_benchmark_summary.csv"
    summary_df.to_csv(csv_filename, index=False)
    print(f"\n✓ Summary saved to '{csv_filename}'")
    print("="*80)
    

    return summary_df
    print("="*80)    return summary_df
    

In [None]:
print("=" * 50)
print("Training HiRA on IMDB with different ranks")
print("=" * 50)

imdb_results = []
for r in [2, 4, 8, 16]:
    print("\n" + "-" * 30)
    print(f"HiRA with rank r={r}")
    print("-" * 30)
    res = train_hira_model(
        dataset_name="imdb",
        model_name="distilbert-base-uncased",
        r=r,
        lora_alpha=32,
        num_epochs=30,        
        batch_size=32,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=100,
        logging_steps=100,
        max_length=256,
        early_stop_patience=3,
        output_dir="./results_hira_imdb",
        resume_from_checkpoint=False,
    )
    imdb_results.append(res)

# Generate and save summary table
generate_summary_table(imdb_results, save_prefix="HiRA_IMDB")

print("\nSummary over ranks:")
for res in imdb_results:
    print(
        f"Rank={res['rank']}, "
        f"Test Acc={res['final_test_accuracy']:.4f}, "
        f"Test F1={res['final_test_f1']:.4f}, "
        f"Val Acc={res['final_val_accuracy']:.4f}, "
        f"Val F1={res['final_val_f1']:.4f}, "
        f"Best Val F1={res['best_val_f1']:.4f}, "
        f"Best F1 Epoch={res['best_f1_epoch']}, "
        f"Trainable Params={res['trainable_parameters']:,}, "
        f"Trainable %={res['trainable_percentage']:.2f}%, "
        f"Total Time (s)={res['total_training_time']:.2f}, "
        f"Avg Epoch (s)={res['average_epoch_time']:.2f}, "
        f"Early Stopped={'Yes' if res['early_stopped'] else 'No'}, "
        f"Peak GPU (MB)={res['peak_gpu_memory_mb']:.2f}, "
        f"Sparsity (<1e-3)={res['sparsity']*100:.2f}%"
    )

Training HiRA on SST-2 with different ranks

------------------------------
HiRA with rank r=2
------------------------------
[imdb][r=2] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

Map: 100%|██████████| 20000/20000 [00:26<00:00, 749.97 examples/s]
Map: 100%|██████████| 2500/2500 [00:03<00:00, 770.62 examples/s]
Map: 100%|██████████| 2500/2500 [00:03<00:00, 758.20 examples/s]
[imdb][r=2] Epoch 1/10: 100%|██████████| 1250/1250 [04:33<00:00,  4.58it/s, loss=0.399]


[imdb][r=2] Epoch 1 | train_loss=0.4822 | val_acc=0.8212 | val_f1=0.8158 | time/epoch=299.35s
[imdb][r=2] Saved best model with val_acc=0.8212 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 2/10: 100%|██████████| 1250/1250 [06:04<00:00,  3.43it/s, loss=0.431] 


[imdb][r=2] Epoch 2 | train_loss=0.3514 | val_acc=0.8300 | val_f1=0.8266 | time/epoch=378.34s
[imdb][r=2] Saved best model with val_acc=0.8300 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 3/10: 100%|██████████| 1250/1250 [05:34<00:00,  3.73it/s, loss=0.652] 


[imdb][r=2] Epoch 3 | train_loss=0.3002 | val_acc=0.8372 | val_f1=0.8379 | time/epoch=350.08s
[imdb][r=2] Saved best model with val_acc=0.8372 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 4/10: 100%|██████████| 1250/1250 [04:59<00:00,  4.17it/s, loss=0.182] 


[imdb][r=2] Epoch 4 | train_loss=0.2598 | val_acc=0.8428 | val_f1=0.8411 | time/epoch=315.13s
[imdb][r=2] Saved best model with val_acc=0.8428 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 5/10: 100%|██████████| 1250/1250 [05:20<00:00,  3.90it/s, loss=0.0897]


[imdb][r=2] Epoch 5 | train_loss=0.2271 | val_acc=0.8472 | val_f1=0.8489 | time/epoch=335.56s
[imdb][r=2] Saved best model with val_acc=0.8472 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 6/10: 100%|██████████| 1250/1250 [05:15<00:00,  3.96it/s, loss=0.163] 


[imdb][r=2] Epoch 6 | train_loss=0.2013 | val_acc=0.8472 | val_f1=0.8470 | time/epoch=331.46s


[imdb][r=2] Epoch 7/10: 100%|██████████| 1250/1250 [05:15<00:00,  3.96it/s, loss=0.119] 


[imdb][r=2] Epoch 7 | train_loss=0.1740 | val_acc=0.8452 | val_f1=0.8435 | time/epoch=330.83s


[imdb][r=2] Epoch 8/10: 100%|██████████| 1250/1250 [05:11<00:00,  4.01it/s, loss=0.49]  


[imdb][r=2] Epoch 8 | train_loss=0.1591 | val_acc=0.8480 | val_f1=0.8475 | time/epoch=327.27s


[imdb][r=2] Epoch 9/10: 100%|██████████| 1250/1250 [05:12<00:00,  4.00it/s, loss=0.432] 


[imdb][r=2] Epoch 9 | train_loss=0.1454 | val_acc=0.8504 | val_f1=0.8508 | time/epoch=327.03s
[imdb][r=2] Saved best model with val_acc=0.8504 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 10/10: 100%|██████████| 1250/1250 [05:17<00:00,  3.94it/s, loss=0.167] 


[imdb][r=2] Epoch 10 | train_loss=0.1406 | val_acc=0.8492 | val_f1=0.8513 | time/epoch=333.76s
[imdb][r=2] Saved best model with val_acc=0.8492 to ./results_hira/best_model_r2.pt
[imdb][r=2] Training completed! best_val_acc=0.8492, avg_time/epoch=332.88s, converge_epoch=1, trainable_params=24612098, trainable_ratio=36.6683%
[imdb][r=2] Final test accuracy: 0.8460, test_f1=0.8523

------------------------------
HiRA with rank r=4
------------------------------
[imdb][r=4] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=4] Epoch 1/10: 100%|██████████| 1250/1250 [05:53<00:00,  3.53it/s, loss=0.287]


[imdb][r=4] Epoch 1 | train_loss=0.4814 | val_acc=0.8192 | val_f1=0.8167 | time/epoch=371.73s
[imdb][r=4] Saved best model with val_acc=0.8192 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 2/10: 100%|██████████| 1250/1250 [06:08<00:00,  3.39it/s, loss=0.383] 


[imdb][r=4] Epoch 2 | train_loss=0.3510 | val_acc=0.8328 | val_f1=0.8321 | time/epoch=386.08s
[imdb][r=4] Saved best model with val_acc=0.8328 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 3/10: 100%|██████████| 1250/1250 [06:05<00:00,  3.42it/s, loss=0.173] 


[imdb][r=4] Epoch 3 | train_loss=0.3018 | val_acc=0.8400 | val_f1=0.8403 | time/epoch=382.79s
[imdb][r=4] Saved best model with val_acc=0.8400 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 4/10: 100%|██████████| 1250/1250 [05:47<00:00,  3.59it/s, loss=0.311] 


[imdb][r=4] Epoch 4 | train_loss=0.2607 | val_acc=0.8392 | val_f1=0.8376 | time/epoch=364.35s


[imdb][r=4] Epoch 5/10: 100%|██████████| 1250/1250 [05:29<00:00,  3.80it/s, loss=0.504] 


[imdb][r=4] Epoch 5 | train_loss=0.2266 | val_acc=0.8428 | val_f1=0.8424 | time/epoch=344.92s
[imdb][r=4] Saved best model with val_acc=0.8428 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 6/10: 100%|██████████| 1250/1250 [05:16<00:00,  3.95it/s, loss=0.131] 


[imdb][r=4] Epoch 6 | train_loss=0.1976 | val_acc=0.8412 | val_f1=0.8411 | time/epoch=332.43s


[imdb][r=4] Epoch 7/10: 100%|██████████| 1250/1250 [05:24<00:00,  3.85it/s, loss=0.151] 


[imdb][r=4] Epoch 7 | train_loss=0.1764 | val_acc=0.8456 | val_f1=0.8468 | time/epoch=340.96s
[imdb][r=4] Saved best model with val_acc=0.8456 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 8/10: 100%|██████████| 1250/1250 [05:33<00:00,  3.75it/s, loss=0.166] 


[imdb][r=4] Epoch 8 | train_loss=0.1588 | val_acc=0.8460 | val_f1=0.8475 | time/epoch=349.01s
[imdb][r=4] Saved best model with val_acc=0.8460 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 9/10: 100%|██████████| 1250/1250 [05:28<00:00,  3.81it/s, loss=0.17]   


[imdb][r=4] Epoch 9 | train_loss=0.1462 | val_acc=0.8448 | val_f1=0.8459 | time/epoch=344.04s


[imdb][r=4] Epoch 10/10: 100%|██████████| 1250/1250 [05:32<00:00,  3.76it/s, loss=0.0692]


[imdb][r=4] Epoch 10 | train_loss=0.1393 | val_acc=0.8444 | val_f1=0.8455 | time/epoch=348.89s
[imdb][r=4] Training completed! best_val_acc=0.8460, avg_time/epoch=356.52s, converge_epoch=1, trainable_params=24777986, trainable_ratio=36.8244%
[imdb][r=4] Final test accuracy: 0.8472, test_f1=0.8526

------------------------------
HiRA with rank r=8
------------------------------
[imdb][r=8] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=8] Epoch 1/10: 100%|██████████| 1250/1250 [05:31<00:00,  3.77it/s, loss=0.361] 


[imdb][r=8] Epoch 1 | train_loss=0.4819 | val_acc=0.8172 | val_f1=0.8165 | time/epoch=346.88s
[imdb][r=8] Saved best model with val_acc=0.8172 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 2/10: 100%|██████████| 1250/1250 [05:22<00:00,  3.88it/s, loss=0.467] 


[imdb][r=8] Epoch 2 | train_loss=0.3487 | val_acc=0.8308 | val_f1=0.8314 | time/epoch=337.78s
[imdb][r=8] Saved best model with val_acc=0.8308 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 3/10: 100%|██████████| 1250/1250 [05:19<00:00,  3.92it/s, loss=0.344] 


[imdb][r=8] Epoch 3 | train_loss=0.2968 | val_acc=0.8360 | val_f1=0.8373 | time/epoch=334.93s
[imdb][r=8] Saved best model with val_acc=0.8360 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 4/10: 100%|██████████| 1250/1250 [05:19<00:00,  3.91it/s, loss=0.442] 


[imdb][r=8] Epoch 4 | train_loss=0.2576 | val_acc=0.8412 | val_f1=0.8434 | time/epoch=334.71s
[imdb][r=8] Saved best model with val_acc=0.8412 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 5/10: 100%|██████████| 1250/1250 [05:17<00:00,  3.93it/s, loss=0.346] 


[imdb][r=8] Epoch 5 | train_loss=0.2241 | val_acc=0.8420 | val_f1=0.8428 | time/epoch=333.13s


[imdb][r=8] Epoch 6/10: 100%|██████████| 1250/1250 [05:21<00:00,  3.89it/s, loss=0.322] 


[imdb][r=8] Epoch 6 | train_loss=0.1944 | val_acc=0.8412 | val_f1=0.8430 | time/epoch=336.76s


[imdb][r=8] Epoch 7/10: 100%|██████████| 1250/1250 [05:18<00:00,  3.93it/s, loss=0.0956]


[imdb][r=8] Epoch 7 | train_loss=0.1697 | val_acc=0.8400 | val_f1=0.8390 | time/epoch=333.72s


[imdb][r=8] Epoch 8/10: 100%|██████████| 1250/1250 [05:16<00:00,  3.95it/s, loss=0.612] 


[imdb][r=8] Epoch 8 | train_loss=0.1567 | val_acc=0.8404 | val_f1=0.8403 | time/epoch=331.87s


[imdb][r=8] Epoch 9/10: 100%|██████████| 1250/1250 [05:16<00:00,  3.95it/s, loss=0.251] 


[imdb][r=8] Epoch 9 | train_loss=0.1452 | val_acc=0.8416 | val_f1=0.8420 | time/epoch=330.89s


[imdb][r=8] Epoch 10/10: 100%|██████████| 1250/1250 [28:57<00:00,  1.39s/it, loss=0.0934]   


[imdb][r=8] Epoch 10 | train_loss=0.1365 | val_acc=0.8416 | val_f1=0.8422 | time/epoch=264.35s
[imdb][r=8] Training completed! best_val_acc=0.8412, avg_time/epoch=328.50s, converge_epoch=1, trainable_params=25109762, trainable_ratio=37.1344%
[imdb][r=8] Final test accuracy: 0.8448, test_f1=0.8507

------------------------------
HiRA with rank r=16
------------------------------
[imdb][r=16] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=16] Epoch 1/10: 100%|██████████| 1250/1250 [29:10<00:00,  1.40s/it, loss=0.552]   


[imdb][r=16] Epoch 1 | train_loss=0.4907 | val_acc=0.8228 | val_f1=0.8200 | time/epoch=256.82s
[imdb][r=16] Saved best model with val_acc=0.8228 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 2/10: 100%|██████████| 1250/1250 [45:04<00:00,  2.16s/it, loss=0.542]    


[imdb][r=16] Epoch 2 | train_loss=0.3496 | val_acc=0.8332 | val_f1=0.8343 | time/epoch=252.57s
[imdb][r=16] Saved best model with val_acc=0.8332 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 3/10: 100%|██████████| 1250/1250 [37:40<00:00,  1.81s/it, loss=0.198]   


[imdb][r=16] Epoch 3 | train_loss=0.2992 | val_acc=0.8396 | val_f1=0.8381 | time/epoch=255.30s
[imdb][r=16] Saved best model with val_acc=0.8396 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 4/10: 100%|██████████| 1250/1250 [57:42<00:00,  2.77s/it, loss=0.142]    


[imdb][r=16] Epoch 4 | train_loss=0.2591 | val_acc=0.8392 | val_f1=0.8370 | time/epoch=249.20s


[imdb][r=16] Epoch 5/10: 100%|██████████| 1250/1250 [56:46<00:00,  2.73s/it, loss=0.307]   


[imdb][r=16] Epoch 5 | train_loss=0.2251 | val_acc=0.8448 | val_f1=0.8469 | time/epoch=251.03s
[imdb][r=16] Saved best model with val_acc=0.8448 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 6/10: 100%|██████████| 1250/1250 [31:53<00:00,  1.53s/it, loss=0.346]     


[imdb][r=16] Epoch 6 | train_loss=0.1971 | val_acc=0.8444 | val_f1=0.8467 | time/epoch=249.67s


[imdb][r=16] Epoch 7/10: 100%|██████████| 1250/1250 [53:09<00:00,  2.55s/it, loss=0.0767]   


[imdb][r=16] Epoch 7 | train_loss=0.1712 | val_acc=0.8440 | val_f1=0.8466 | time/epoch=253.67s


[imdb][r=16] Epoch 8/10: 100%|██████████| 1250/1250 [28:20<00:00,  1.36s/it, loss=0.361]    


[imdb][r=16] Epoch 8 | train_loss=0.1549 | val_acc=0.8432 | val_f1=0.8447 | time/epoch=253.54s


[imdb][r=16] Epoch 9/10: 100%|██████████| 1250/1250 [1:00:30<00:00,  2.90s/it, loss=0.151]    


[imdb][r=16] Epoch 9 | train_loss=0.1439 | val_acc=0.8444 | val_f1=0.8457 | time/epoch=248.47s


[imdb][r=16] Epoch 10/10: 100%|██████████| 1250/1250 [52:38<00:00,  2.53s/it, loss=0.289]     


[imdb][r=16] Epoch 10 | train_loss=0.1355 | val_acc=0.8448 | val_f1=0.8453 | time/epoch=248.98s
[imdb][r=16] Training completed! best_val_acc=0.8448, avg_time/epoch=251.93s, converge_epoch=1, trainable_params=25773314, trainable_ratio=37.7453%
[imdb][r=16] Final test accuracy: 0.8424, test_f1=0.8483

Result Summary Table:

| Rank | Trainable Params / Total | Ratio | Val F1 | Val Acc | Test F1 | Test Acc | Sparsity (<1e−3) | Train Time (s) |
| ---- | ------------------------ | ----- | ------ | ------- | ------- | -------- | ----------------- | -------------- |
| 1 | 24,777,986 / 67.3M | 36.82% | 0.8475 | 0.8460 | 0.8526 | 0.8472 | 1.77% | 3565.19 |
| 2 | 24,612,098 / 67.1M | 36.67% | 0.8513 | 0.8492 | 0.8523 | 0.8460 | 1.75% | 3328.82 |
| 3 | 25,109,762 / 67.6M | 37.13% | 0.8434 | 0.8412 | 0.8507 | 0.8448 | 1.80% | 3285.02 |
| 4 | 25,773,314 / 68.3M | 37.75% | 0.8469 | 0.8448 | 0.8483 | 0.8424 | 1.88% | 2519.27 |


Summary over ranks:
r=2: val_acc=0.8492, test_acc=0.8460, avg_time/epoch