In [2]:
pip install -r requirements.txt


Collecting torch>=2.0.0 (from -r requirements.txt (line 1))
  Downloading torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting transformers>=4.30.0 (from -r requirements.txt (line 2))
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting datasets (from -r requirements.txt (line 3))
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting peft (from -r requirements.txt (line 4))
  Downloading peft-0.18.0-py3-none-any.whl.metadata (14 kB)
Collecting accelerate (from -r requirements.txt (line 5))
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate (from -r requirements.txt (line 6))
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting scikit-learn (from -r requirements.txt (line 7))
  Downloading scikit_learn-1.7.2-cp314-cp314-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting tqdm (from -r requirements.txt (line 8))
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57

In [1]:
import random
import numpy as np

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)


NameError: name 'torch' is not defined

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import os

class HiRALayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 32,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
    ):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        
        self.lora_A = nn.Parameter(torch.zeros(in_features, r))
        self.lora_B = nn.Parameter(torch.randn(r, out_features))
        
        nn.init.zeros_(self.lora_A)
        nn.init.kaiming_uniform_(self.lora_B, a=0)
        
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
        
    def forward(self, x: torch.Tensor, W0: torch.Tensor) -> torch.Tensor:
        result = F.linear(x, W0)
        
        lora_update = self.lora_A @ self.lora_B  # [in_features, out_features]
        
        hadamard_update = W0.T * lora_update  # [in_features, out_features]
        
        result += self.lora_dropout(x @ hadamard_update) * self.scaling
        
        return result


class HiRALinear(nn.Module):
    def __init__(self, linear_layer: nn.Linear, r: int = 32, lora_alpha: int = 32):
        super().__init__()
        self.linear = linear_layer
        self.hira = HiRALayer(
            in_features=linear_layer.in_features,
            out_features=linear_layer.out_features,
            r=r,
            lora_alpha=lora_alpha
        )
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.hira(x, self.linear.weight)




In [3]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
import torch.nn as nn

def apply_hira_to_model(model, r=32, lora_alpha=32, target_modules=['q_lin', 'k_lin', 'v_lin', 'out_lin', 'ffn.lin1', 'ffn.lin2']):

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if any(target in name for target in target_modules):
                parent_name = '.'.join(name.split('.')[:-1])
                attr_name = name.split('.')[-1]
                
                if parent_name:
                    parent_module = model.get_submodule(parent_name)
                else:
                    parent_module = model
                
                hira_layer = HiRALinear(module, r=r, lora_alpha=lora_alpha)
                setattr(parent_module, attr_name, hira_layer)
                
                print(f"Applied HiRA to: {name}")
    
    return model

def get_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, all_params, 100 * trainable_params / all_params


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

def load_sst2_data(tokenizer, max_length=128):
    """8:1:1 """
    raw = load_dataset("glue", "sst2")  # 有 train / validation / test[web:110]

    train_valid = raw["train"]
    train_valid = train_valid.shuffle(seed=42)
    n = len(train_valid)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val

    train_dataset = train_valid.select(range(n_train))
    val_dataset   = train_valid.select(range(n_train, n_train + n_val))
    test_dataset  = train_valid.select(range(n_train + n_val, n))

    def preprocess_function(examples):
        enc = tokenizer(
            examples["sentence"],
            truncation=True,
            max_length=max_length
        )
        enc["labels"] = examples["label"]
        return enc

    train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
    val_dataset   = val_dataset.map(preprocess_function,   batched=True, remove_columns=val_dataset.column_names)
    test_dataset  = test_dataset.map(preprocess_function,  batched=True, remove_columns=test_dataset.column_names)

    return {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }



def load_imdb_data(tokenizer, max_length=256):

    raw = load_dataset("imdb")  

    train_full = raw["train"]               
    train_full = train_full.shuffle(seed=42)
    n = len(train_full)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val

    train_dataset = train_full.select(range(n_train))
    val_dataset   = train_full.select(range(n_train, n_train + n_val))
    test_dataset  = train_full.select(range(n_train + n_val, n))

    def preprocess_function(examples):
        enc = tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
        )
        enc["labels"] = examples["label"]
        return enc

    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    val_dataset = val_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=val_dataset.column_names,
    )
    test_dataset = test_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=test_dataset.column_names,
    )

    return {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }



# After changing

In [5]:
import os
import time
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from tqdm import tqdm
import evaluate
from datasets import load_dataset




def get_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for p in model.parameters():
        num = p.numel()
        all_params += num
        if p.requires_grad:
            trainable_params += num
    percentage = 100.0 * trainable_params / all_params if all_params > 0 else 0.0
    return trainable_params, all_params, percentage


def get_model_sparsity(model, threshold: float = 1e-3) -> float:
    total_elems = 0
    small_elems = 0
    for p in model.parameters():
        if p is None:
            continue
        data = p.detach()
        total_elems += data.numel()
        small_elems += (data.abs() < threshold).sum().item()
    if total_elems == 0:
        return 0.0
    return small_elems / total_elems

def train_hira_model(
    dataset_name: str = "sst2",
    model_name: str = "distilbert-base-uncased",
    r: int = 32,
    lora_alpha: int = 32,
    num_epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    warmup_steps: int = 100,
    max_length: int = 128,
    output_dir: str = "./results_hira",
):

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"[{dataset_name}][r={r}] Using device: {device}")

    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        return_tensors="pt",
    )

    num_labels = 2
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    model = apply_hira_to_model(model, r=r, lora_alpha=lora_alpha)
    model = model.to(device)

    trainable_params, all_params, percentage = get_trainable_parameters(model)
    print(
        f"[{dataset_name}][r={r}] Trainable params: {trainable_params:,} || "
        f"All params: {all_params:,} || Trainable%: {percentage:.4f}%"
    )

    if dataset_name == "sst2":
        dataset = load_sst2_data(tokenizer, max_length)
    elif dataset_name == "imdb":
        dataset = load_imdb_data(tokenizer, max_length)
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    train_dataloader = DataLoader(
        dataset["train"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
    )
    val_split_name = "validation" if "validation" in dataset else "test"
    val_dataloader = DataLoader(
        dataset[val_split_name],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    test_dataloader = DataLoader(
        dataset["test"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
    )
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    os.makedirs(output_dir, exist_ok=True)

    epoch_times = []
    total_train_time = 0.0
    best_val_accuracy = 0.0
    best_val_f1 = 0.0
    best_epoch = -1

    for epoch in range(num_epochs):
        start_t = time.perf_counter()

        model.train()
        total_loss = 0.0
        progress_bar = tqdm(
            train_dataloader,
            desc=f"[{dataset_name}][r={r}] Epoch {epoch + 1}/{num_epochs}",
        )

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_loss / len(train_dataloader)

        acc_metric = evaluate.load("accuracy")
        f1_metric = evaluate.load("f1")
        model.eval()
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=-1)
            acc_metric.add_batch(predictions=preds, references=batch["labels"])
            f1_metric.add_batch(predictions=preds, references=batch["labels"])

        val_acc = acc_metric.compute()["accuracy"]
        val_f1 = f1_metric.compute()["f1"]

        end_t = time.perf_counter()
        epoch_time = end_t - start_t
        epoch_times.append(epoch_time)
        total_train_time += epoch_time

        print(
            f"[{dataset_name}][r={r}] Epoch {epoch + 1} | "
            f"train_loss={avg_train_loss:.4f} | "
            f"val_acc={val_acc:.4f} | "
            f"val_f1={val_f1:.4f} | "
            f"time/epoch={epoch_time:.2f}s"
        )

        if val_f1 > best_val_f1:
            best_val_accuracy = val_acc
            best_val_f1 = val_f1
            if best_epoch == -1:
                best_epoch = epoch + 1

            save_path = os.path.join(output_dir, f"best_model_r{r}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "accuracy": best_val_accuracy,
                },
                save_path,
            )
            print(
                f"[{dataset_name}][r={r}] Saved best model with "
                f"val_acc={best_val_accuracy:.4f} to {save_path}"
            )

    avg_time_per_epoch = sum(epoch_times) / len(epoch_times)
    print(
        f"[{dataset_name}][r={r}] Training completed! "
        f"best_val_acc={best_val_accuracy:.4f}, "
        f"avg_time/epoch={avg_time_per_epoch:.2f}s, "
        f"converge_epoch={best_epoch}, "
        f"trainable_params={trainable_params}, "
        f"trainable_ratio={percentage:.4f}%"
    )

    acc_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")
    model.eval()
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=-1)
        acc_metric.add_batch(predictions=preds, references=batch["labels"])
        f1_metric.add_batch(predictions=preds, references=batch["labels"])

    test_acc = acc_metric.compute()["accuracy"]
    test_f1 = f1_metric.compute()["f1"]
    print(f"[{dataset_name}][r={r}] Final test accuracy: {test_acc:.4f}, test_f1={test_f1:.4f}")

    sparsity = get_model_sparsity(model, threshold=1e-3)

    return {
        "dataset": dataset_name,
        "r": r,
        "best_val_acc": best_val_accuracy,
        "best_val_f1": best_val_f1,
        "test_acc": test_acc,
        "test_f1": test_f1,
        "avg_time_per_epoch": avg_time_per_epoch,
        "total_train_time": total_train_time,
        "converge_epoch": best_epoch,
        "trainable_params": trainable_params,
        "total_params": all_params,
        "trainable_ratio": percentage,   
        "sparsity": sparsity,            
    }


def format_int_with_commas(x: int) -> str:
    return f"{x:,}"


def generate_summary_table(results):
    results_sorted = sorted(
        results,
        key=lambda d: d["test_acc"],
        reverse=True,
    )

    header = (
        "| Rank | Trainable Params / Total | Ratio | Val F1 | Val Acc | "
        "Test F1 | Test Acc | Sparsity (<1e−3) | Train Time (s) |\n"
    )
    sep = (
        "| ---- | ------------------------ | ----- | ------ | ------- | "
        "------- | -------- | ----------------- | -------------- |\n"
    )

    lines = [header, sep]
    for rank, res in enumerate(results_sorted, start=1):
        trainable = format_int_with_commas(res["trainable_params"])
        total = f"{res['total_params']/1e6:.1f}M"
        ratio = f"{res['trainable_ratio']:.2f}%"
        val_f1 = f"{res['best_val_f1']:.4f}"
        val_acc = f"{res['best_val_acc']:.4f}"
        test_f1 = f"{res['test_f1']:.4f}"
        test_acc = f"{res['test_acc']:.4f}"
        sparsity = f"{res['sparsity']*100:.2f}%"
        train_time = f"{res['total_train_time']:.2f}"

        line = (
            f"| {rank} | {trainable} / {total} | {ratio} | "
            f"{val_f1} | {val_acc} | {test_f1} | {test_acc} | "
            f"{sparsity} | {train_time} |\n"
        )
        lines.append(line)

    table_md = "".join(lines)
    print("\nResult Summary Table:\n")
    print(table_md)


In [6]:

print("=" * 50)
print("Training HiRA on SST-2 with different ranks")
print("=" * 50)

results = []
for r in [2, 4, 8, 16]:
    print("\n" + "-" * 30)
    print(f"HiRA with rank r={r}")
    print("-" * 30)
    res = train_hira_model(
        dataset_name="sst2",
        model_name="distilbert-base-uncased",
        r=r,
        lora_alpha=32,
        num_epochs=10,        
        batch_size=16,
        learning_rate=2e-5,  
        warmup_steps=100,
        max_length=128,
        output_dir="./results_hira",
    )
    results.append(res)

generate_summary_table(results)

print("\nSummary over ranks:")
for res in results:
    print(
        f"r={res['r']}: "
        f"val_acc={res['best_val_acc']:.4f}, "
        f"test_acc={res['test_acc']:.4f}, "
        f"avg_time/epoch={res['avg_time_per_epoch']:.2f}s, "
        f"converge_epoch={res['converge_epoch']}, "
        f"trainable={res['trainable_params']} ({res['trainable_ratio']:.4f}%)"
    )



Training HiRA on SST-2 with different ranks

------------------------------
HiRA with rank r=2
------------------------------
[sst2][r=2] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=2] Epoch 1/10: 100%|██████████| 3368/3368 [04:55<00:00, 11.40it/s, loss=0.373] 
Downloading builder script: 6.79kB [00:00, 9.61MB/s]


[sst2][r=2] Epoch 1 | train_loss=0.3480 | val_acc=0.9024 | val_f1=0.9131 | time/epoch=319.12s
[sst2][r=2] Saved best model with val_acc=0.9024 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 2/10: 100%|██████████| 3368/3368 [05:47<00:00,  9.70it/s, loss=0.102] 


[sst2][r=2] Epoch 2 | train_loss=0.2189 | val_acc=0.9168 | val_f1=0.9262 | time/epoch=361.93s
[sst2][r=2] Saved best model with val_acc=0.9168 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 3/10: 100%|██████████| 3368/3368 [05:19<00:00, 10.55it/s, loss=0.028]  


[sst2][r=2] Epoch 3 | train_loss=0.1712 | val_acc=0.9265 | val_f1=0.9342 | time/epoch=334.22s
[sst2][r=2] Saved best model with val_acc=0.9265 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 4/10: 100%|██████████| 3368/3368 [05:04<00:00, 11.06it/s, loss=0.393]  


[sst2][r=2] Epoch 4 | train_loss=0.1414 | val_acc=0.9317 | val_f1=0.9390 | time/epoch=320.76s
[sst2][r=2] Saved best model with val_acc=0.9317 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 5/10: 100%|██████████| 3368/3368 [05:17<00:00, 10.61it/s, loss=0.591]  


[sst2][r=2] Epoch 5 | train_loss=0.1203 | val_acc=0.9330 | val_f1=0.9401 | time/epoch=333.44s
[sst2][r=2] Saved best model with val_acc=0.9330 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 6/10: 100%|██████████| 3368/3368 [05:15<00:00, 10.69it/s, loss=0.00773]


[sst2][r=2] Epoch 6 | train_loss=0.1079 | val_acc=0.9344 | val_f1=0.9411 | time/epoch=331.25s
[sst2][r=2] Saved best model with val_acc=0.9344 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 7/10: 100%|██████████| 3368/3368 [05:18<00:00, 10.58it/s, loss=0.349]  


[sst2][r=2] Epoch 7 | train_loss=0.0973 | val_acc=0.9348 | val_f1=0.9416 | time/epoch=334.46s
[sst2][r=2] Saved best model with val_acc=0.9348 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 8/10: 100%|██████████| 3368/3368 [05:20<00:00, 10.51it/s, loss=0.0243] 


[sst2][r=2] Epoch 8 | train_loss=0.0884 | val_acc=0.9361 | val_f1=0.9430 | time/epoch=336.72s
[sst2][r=2] Saved best model with val_acc=0.9361 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 9/10: 100%|██████████| 3368/3368 [05:13<00:00, 10.75it/s, loss=0.00744]


[sst2][r=2] Epoch 9 | train_loss=0.0850 | val_acc=0.9367 | val_f1=0.9435 | time/epoch=329.15s
[sst2][r=2] Saved best model with val_acc=0.9367 to ./results_hira/best_model_r2.pt


[sst2][r=2] Epoch 10/10: 100%|██████████| 3368/3368 [05:23<00:00, 10.40it/s, loss=0.0977] 


[sst2][r=2] Epoch 10 | train_loss=0.0818 | val_acc=0.9373 | val_f1=0.9441 | time/epoch=340.47s
[sst2][r=2] Saved best model with val_acc=0.9373 to ./results_hira/best_model_r2.pt
[sst2][r=2] Training completed! best_val_acc=0.9373, avg_time/epoch=334.15s, converge_epoch=1, trainable_params=24612098, trainable_ratio=36.6683%
[sst2][r=2] Final test accuracy: 0.9382, test_f1=0.9458

------------------------------
HiRA with rank r=4
------------------------------
[sst2][r=4] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=4] Epoch 1/10: 100%|██████████| 3368/3368 [05:35<00:00, 10.05it/s, loss=0.101] 


[sst2][r=4] Epoch 1 | train_loss=0.3418 | val_acc=0.9004 | val_f1=0.9111 | time/epoch=353.02s
[sst2][r=4] Saved best model with val_acc=0.9004 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 2/10: 100%|██████████| 3368/3368 [05:24<00:00, 10.38it/s, loss=0.0792] 


[sst2][r=4] Epoch 2 | train_loss=0.2167 | val_acc=0.9182 | val_f1=0.9273 | time/epoch=341.24s
[sst2][r=4] Saved best model with val_acc=0.9182 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 3/10: 100%|██████████| 3368/3368 [05:34<00:00, 10.06it/s, loss=0.012]  


[sst2][r=4] Epoch 3 | train_loss=0.1692 | val_acc=0.9266 | val_f1=0.9347 | time/epoch=351.44s
[sst2][r=4] Saved best model with val_acc=0.9266 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 4/10: 100%|██████████| 3368/3368 [05:19<00:00, 10.54it/s, loss=0.316]  


[sst2][r=4] Epoch 4 | train_loss=0.1401 | val_acc=0.9315 | val_f1=0.9387 | time/epoch=335.78s
[sst2][r=4] Saved best model with val_acc=0.9315 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 5/10: 100%|██████████| 3368/3368 [05:16<00:00, 10.65it/s, loss=0.0206] 


[sst2][r=4] Epoch 5 | train_loss=0.1215 | val_acc=0.9335 | val_f1=0.9403 | time/epoch=332.47s
[sst2][r=4] Saved best model with val_acc=0.9335 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 6/10: 100%|██████████| 3368/3368 [05:14<00:00, 10.70it/s, loss=0.109]  


[sst2][r=4] Epoch 6 | train_loss=0.1070 | val_acc=0.9350 | val_f1=0.9418 | time/epoch=330.67s
[sst2][r=4] Saved best model with val_acc=0.9350 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 7/10: 100%|██████████| 3368/3368 [05:16<00:00, 10.64it/s, loss=0.00793]


[sst2][r=4] Epoch 7 | train_loss=0.0966 | val_acc=0.9358 | val_f1=0.9426 | time/epoch=332.51s
[sst2][r=4] Saved best model with val_acc=0.9358 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 8/10: 100%|██████████| 3368/3368 [05:13<00:00, 10.76it/s, loss=0.0219] 


[sst2][r=4] Epoch 8 | train_loss=0.0904 | val_acc=0.9375 | val_f1=0.9442 | time/epoch=329.21s
[sst2][r=4] Saved best model with val_acc=0.9375 to ./results_hira/best_model_r4.pt


[sst2][r=4] Epoch 9/10: 100%|██████████| 3368/3368 [05:19<00:00, 10.54it/s, loss=0.00822]


[sst2][r=4] Epoch 9 | train_loss=0.0849 | val_acc=0.9366 | val_f1=0.9433 | time/epoch=336.32s


[sst2][r=4] Epoch 10/10: 100%|██████████| 3368/3368 [05:24<00:00, 10.37it/s, loss=0.00689]


[sst2][r=4] Epoch 10 | train_loss=0.0808 | val_acc=0.9370 | val_f1=0.9438 | time/epoch=341.41s
[sst2][r=4] Training completed! best_val_acc=0.9375, avg_time/epoch=338.41s, converge_epoch=1, trainable_params=24777986, trainable_ratio=36.8244%
[sst2][r=4] Final test accuracy: 0.9382, test_f1=0.9459

------------------------------
HiRA with rank r=8
------------------------------
[sst2][r=8] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=8] Epoch 1/10: 100%|██████████| 3368/3368 [30:20<00:00,  1.85it/s, loss=0.121]    


[sst2][r=8] Epoch 1 | train_loss=0.3449 | val_acc=0.9004 | val_f1=0.9097 | time/epoch=326.80s
[sst2][r=8] Saved best model with val_acc=0.9004 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 2/10: 100%|██████████| 3368/3368 [20:04<00:00,  2.80it/s, loss=0.36]     


[sst2][r=8] Epoch 2 | train_loss=0.2182 | val_acc=0.9176 | val_f1=0.9270 | time/epoch=304.32s
[sst2][r=8] Saved best model with val_acc=0.9176 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 3/10: 100%|██████████| 3368/3368 [05:39<00:00,  9.92it/s, loss=0.118]  


[sst2][r=8] Epoch 3 | train_loss=0.1714 | val_acc=0.9241 | val_f1=0.9321 | time/epoch=355.02s
[sst2][r=8] Saved best model with val_acc=0.9241 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 4/10: 100%|██████████| 3368/3368 [05:04<00:00, 11.04it/s, loss=0.0443] 


[sst2][r=8] Epoch 4 | train_loss=0.1411 | val_acc=0.9275 | val_f1=0.9359 | time/epoch=321.24s
[sst2][r=8] Saved best model with val_acc=0.9275 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 5/10: 100%|██████████| 3368/3368 [05:08<00:00, 10.91it/s, loss=0.00425]


[sst2][r=8] Epoch 5 | train_loss=0.1213 | val_acc=0.9321 | val_f1=0.9392 | time/epoch=324.63s
[sst2][r=8] Saved best model with val_acc=0.9321 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 6/10: 100%|██████████| 3368/3368 [05:18<00:00, 10.58it/s, loss=0.11]   


[sst2][r=8] Epoch 6 | train_loss=0.1089 | val_acc=0.9347 | val_f1=0.9418 | time/epoch=335.23s
[sst2][r=8] Saved best model with val_acc=0.9347 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 7/10: 100%|██████████| 3368/3368 [05:22<00:00, 10.46it/s, loss=0.744]  


[sst2][r=8] Epoch 7 | train_loss=0.0963 | val_acc=0.9353 | val_f1=0.9424 | time/epoch=340.79s
[sst2][r=8] Saved best model with val_acc=0.9353 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 8/10: 100%|██████████| 3368/3368 [05:25<00:00, 10.33it/s, loss=0.0421] 


[sst2][r=8] Epoch 8 | train_loss=0.0888 | val_acc=0.9361 | val_f1=0.9431 | time/epoch=342.71s
[sst2][r=8] Saved best model with val_acc=0.9361 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 9/10: 100%|██████████| 3368/3368 [05:22<00:00, 10.46it/s, loss=0.000541]


[sst2][r=8] Epoch 9 | train_loss=0.0851 | val_acc=0.9361 | val_f1=0.9432 | time/epoch=338.74s
[sst2][r=8] Saved best model with val_acc=0.9361 to ./results_hira/best_model_r8.pt


[sst2][r=8] Epoch 10/10: 100%|██████████| 3368/3368 [05:19<00:00, 10.55it/s, loss=0.00243]


[sst2][r=8] Epoch 10 | train_loss=0.0815 | val_acc=0.9363 | val_f1=0.9432 | time/epoch=335.66s
[sst2][r=8] Saved best model with val_acc=0.9363 to ./results_hira/best_model_r8.pt
[sst2][r=8] Training completed! best_val_acc=0.9363, avg_time/epoch=332.51s, converge_epoch=1, trainable_params=25109762, trainable_ratio=37.1344%
[sst2][r=8] Final test accuracy: 0.9387, test_f1=0.9462

------------------------------
HiRA with rank r=16
------------------------------
[sst2][r=16] Using device: mps


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 2dd09c47-c9e4-4c30-8eb7-5aad09983146)')' thrown while requesting HEAD https://huggingface.co/distilbert-base-uncased/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=16] Epoch 1/10: 100%|██████████| 3368/3368 [05:06<00:00, 10.99it/s, loss=0.996] 


[sst2][r=16] Epoch 1 | train_loss=0.3456 | val_acc=0.9018 | val_f1=0.9128 | time/epoch=323.77s
[sst2][r=16] Saved best model with val_acc=0.9018 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 2/10: 100%|██████████| 3368/3368 [05:26<00:00, 10.33it/s, loss=0.672] 


[sst2][r=16] Epoch 2 | train_loss=0.2187 | val_acc=0.9174 | val_f1=0.9267 | time/epoch=342.33s
[sst2][r=16] Saved best model with val_acc=0.9174 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 3/10: 100%|██████████| 3368/3368 [05:24<00:00, 10.39it/s, loss=0.0777] 


[sst2][r=16] Epoch 3 | train_loss=0.1707 | val_acc=0.9275 | val_f1=0.9355 | time/epoch=341.72s
[sst2][r=16] Saved best model with val_acc=0.9275 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 4/10: 100%|██████████| 3368/3368 [05:30<00:00, 10.20it/s, loss=0.0636] 


[sst2][r=16] Epoch 4 | train_loss=0.1414 | val_acc=0.9298 | val_f1=0.9375 | time/epoch=347.25s
[sst2][r=16] Saved best model with val_acc=0.9298 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 5/10: 100%|██████████| 3368/3368 [05:24<00:00, 10.38it/s, loss=0.131]  


[sst2][r=16] Epoch 5 | train_loss=0.1212 | val_acc=0.9335 | val_f1=0.9405 | time/epoch=340.76s
[sst2][r=16] Saved best model with val_acc=0.9335 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 6/10: 100%|██████████| 3368/3368 [05:19<00:00, 10.53it/s, loss=0.00691]


[sst2][r=16] Epoch 6 | train_loss=0.1071 | val_acc=0.9341 | val_f1=0.9408 | time/epoch=336.11s
[sst2][r=16] Saved best model with val_acc=0.9341 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 7/10: 100%|██████████| 3368/3368 [05:22<00:00, 10.44it/s, loss=0.0108] 


[sst2][r=16] Epoch 7 | train_loss=0.0972 | val_acc=0.9356 | val_f1=0.9426 | time/epoch=340.30s
[sst2][r=16] Saved best model with val_acc=0.9356 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 8/10: 100%|██████████| 3368/3368 [05:27<00:00, 10.30it/s, loss=0.0396] 


[sst2][r=16] Epoch 8 | train_loss=0.0911 | val_acc=0.9361 | val_f1=0.9429 | time/epoch=344.36s
[sst2][r=16] Saved best model with val_acc=0.9361 to ./results_hira/best_model_r16.pt


[sst2][r=16] Epoch 9/10: 100%|██████████| 3368/3368 [05:25<00:00, 10.33it/s, loss=0.0782] 


[sst2][r=16] Epoch 9 | train_loss=0.0848 | val_acc=0.9360 | val_f1=0.9429 | time/epoch=342.16s


[sst2][r=16] Epoch 10/10: 100%|██████████| 3368/3368 [05:32<00:00, 10.14it/s, loss=0.173]  


[sst2][r=16] Epoch 10 | train_loss=0.0830 | val_acc=0.9367 | val_f1=0.9436 | time/epoch=348.33s
[sst2][r=16] Saved best model with val_acc=0.9367 to ./results_hira/best_model_r16.pt
[sst2][r=16] Training completed! best_val_acc=0.9367, avg_time/epoch=340.71s, converge_epoch=1, trainable_params=25773314, trainable_ratio=37.7453%
[sst2][r=16] Final test accuracy: 0.9381, test_f1=0.9458

Result Summary Table:

| Rank | Trainable Params / Total | Ratio | Val F1 | Val Acc | Test F1 | Test Acc | Sparsity (<1e−3) | Train Time (s) |
| ---- | ------------------------ | ----- | ------ | ------- | ------- | -------- | ----------------- | -------------- |
| 1 | 25,109,762 / 67.6M | 37.13% | 0.9432 | 0.9363 | 0.9462 | 0.9387 | 1.79% | 3325.13 |
| 2 | 24,612,098 / 67.1M | 36.67% | 0.9441 | 0.9373 | 0.9458 | 0.9382 | 1.74% | 3341.53 |
| 3 | 24,777,986 / 67.3M | 36.82% | 0.9442 | 0.9375 | 0.9459 | 0.9382 | 1.76% | 3384.06 |
| 4 | 25,773,314 / 68.3M | 37.75% | 0.9436 | 0.9367 | 0.9458 | 0.9381 | 1.84% 

In [7]:
generate_summary_table(results)



Result Summary Table:

| Rank | Trainable Params / Total | Ratio | Val F1 | Val Acc | Test F1 | Test Acc | Sparsity (<1e−3) | Train Time (s) |
| ---- | ------------------------ | ----- | ------ | ------- | ------- | -------- | ----------------- | -------------- |
| 1 | 25,109,762 / 67.6M | 37.13% | 0.9432 | 0.9363 | 0.9462 | 0.9387 | 1.79% | 3325.13 |
| 2 | 24,612,098 / 67.1M | 36.67% | 0.9441 | 0.9373 | 0.9458 | 0.9382 | 1.74% | 3341.53 |
| 3 | 24,777,986 / 67.3M | 36.82% | 0.9442 | 0.9375 | 0.9459 | 0.9382 | 1.76% | 3384.06 |
| 4 | 25,773,314 / 68.3M | 37.75% | 0.9436 | 0.9367 | 0.9458 | 0.9381 | 1.84% | 3407.09 |



In [8]:

print("=" * 50)
print("Training HiRA on SST-2 with different ranks")
print("=" * 50)

imdb_results = []
for r in [2, 4, 8, 16]:
    print("\n" + "-" * 30)
    print(f"HiRA with rank r={r}")
    print("-" * 30)
    res = train_hira_model(
        dataset_name="imdb",
        model_name="distilbert-base-uncased",
        r=r,
        lora_alpha=32,
        num_epochs=10,        
        batch_size=16,
        learning_rate=2e-5,  
        warmup_steps=100,
        max_length=128,
        output_dir="./results_hira",
    )
    imdb_results.append(res)

generate_summary_table(imdb_results)

print("\nSummary over ranks:")
for res in imdb_results:
    print(
        f"r={res['r']}: "
        f"val_acc={res['best_val_acc']:.4f}, "
        f"test_acc={res['test_acc']:.4f}, "
        f"avg_time/epoch={res['avg_time_per_epoch']:.2f}s, "
        f"converge_epoch={res['converge_epoch']}, "
        f"trainable={res['trainable_params']} ({res['trainable_ratio']:.4f}%)"
    )





Training HiRA on SST-2 with different ranks

------------------------------
HiRA with rank r=2
------------------------------
[imdb][r=2] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

Map: 100%|██████████| 20000/20000 [00:26<00:00, 749.97 examples/s]
Map: 100%|██████████| 2500/2500 [00:03<00:00, 770.62 examples/s]
Map: 100%|██████████| 2500/2500 [00:03<00:00, 758.20 examples/s]
[imdb][r=2] Epoch 1/10: 100%|██████████| 1250/1250 [04:33<00:00,  4.58it/s, loss=0.399]


[imdb][r=2] Epoch 1 | train_loss=0.4822 | val_acc=0.8212 | val_f1=0.8158 | time/epoch=299.35s
[imdb][r=2] Saved best model with val_acc=0.8212 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 2/10: 100%|██████████| 1250/1250 [06:04<00:00,  3.43it/s, loss=0.431] 


[imdb][r=2] Epoch 2 | train_loss=0.3514 | val_acc=0.8300 | val_f1=0.8266 | time/epoch=378.34s
[imdb][r=2] Saved best model with val_acc=0.8300 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 3/10: 100%|██████████| 1250/1250 [05:34<00:00,  3.73it/s, loss=0.652] 


[imdb][r=2] Epoch 3 | train_loss=0.3002 | val_acc=0.8372 | val_f1=0.8379 | time/epoch=350.08s
[imdb][r=2] Saved best model with val_acc=0.8372 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 4/10: 100%|██████████| 1250/1250 [04:59<00:00,  4.17it/s, loss=0.182] 


[imdb][r=2] Epoch 4 | train_loss=0.2598 | val_acc=0.8428 | val_f1=0.8411 | time/epoch=315.13s
[imdb][r=2] Saved best model with val_acc=0.8428 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 5/10: 100%|██████████| 1250/1250 [05:20<00:00,  3.90it/s, loss=0.0897]


[imdb][r=2] Epoch 5 | train_loss=0.2271 | val_acc=0.8472 | val_f1=0.8489 | time/epoch=335.56s
[imdb][r=2] Saved best model with val_acc=0.8472 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 6/10: 100%|██████████| 1250/1250 [05:15<00:00,  3.96it/s, loss=0.163] 


[imdb][r=2] Epoch 6 | train_loss=0.2013 | val_acc=0.8472 | val_f1=0.8470 | time/epoch=331.46s


[imdb][r=2] Epoch 7/10: 100%|██████████| 1250/1250 [05:15<00:00,  3.96it/s, loss=0.119] 


[imdb][r=2] Epoch 7 | train_loss=0.1740 | val_acc=0.8452 | val_f1=0.8435 | time/epoch=330.83s


[imdb][r=2] Epoch 8/10: 100%|██████████| 1250/1250 [05:11<00:00,  4.01it/s, loss=0.49]  


[imdb][r=2] Epoch 8 | train_loss=0.1591 | val_acc=0.8480 | val_f1=0.8475 | time/epoch=327.27s


[imdb][r=2] Epoch 9/10: 100%|██████████| 1250/1250 [05:12<00:00,  4.00it/s, loss=0.432] 


[imdb][r=2] Epoch 9 | train_loss=0.1454 | val_acc=0.8504 | val_f1=0.8508 | time/epoch=327.03s
[imdb][r=2] Saved best model with val_acc=0.8504 to ./results_hira/best_model_r2.pt


[imdb][r=2] Epoch 10/10: 100%|██████████| 1250/1250 [05:17<00:00,  3.94it/s, loss=0.167] 


[imdb][r=2] Epoch 10 | train_loss=0.1406 | val_acc=0.8492 | val_f1=0.8513 | time/epoch=333.76s
[imdb][r=2] Saved best model with val_acc=0.8492 to ./results_hira/best_model_r2.pt
[imdb][r=2] Training completed! best_val_acc=0.8492, avg_time/epoch=332.88s, converge_epoch=1, trainable_params=24612098, trainable_ratio=36.6683%
[imdb][r=2] Final test accuracy: 0.8460, test_f1=0.8523

------------------------------
HiRA with rank r=4
------------------------------
[imdb][r=4] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=4] Epoch 1/10: 100%|██████████| 1250/1250 [05:53<00:00,  3.53it/s, loss=0.287]


[imdb][r=4] Epoch 1 | train_loss=0.4814 | val_acc=0.8192 | val_f1=0.8167 | time/epoch=371.73s
[imdb][r=4] Saved best model with val_acc=0.8192 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 2/10: 100%|██████████| 1250/1250 [06:08<00:00,  3.39it/s, loss=0.383] 


[imdb][r=4] Epoch 2 | train_loss=0.3510 | val_acc=0.8328 | val_f1=0.8321 | time/epoch=386.08s
[imdb][r=4] Saved best model with val_acc=0.8328 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 3/10: 100%|██████████| 1250/1250 [06:05<00:00,  3.42it/s, loss=0.173] 


[imdb][r=4] Epoch 3 | train_loss=0.3018 | val_acc=0.8400 | val_f1=0.8403 | time/epoch=382.79s
[imdb][r=4] Saved best model with val_acc=0.8400 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 4/10: 100%|██████████| 1250/1250 [05:47<00:00,  3.59it/s, loss=0.311] 


[imdb][r=4] Epoch 4 | train_loss=0.2607 | val_acc=0.8392 | val_f1=0.8376 | time/epoch=364.35s


[imdb][r=4] Epoch 5/10: 100%|██████████| 1250/1250 [05:29<00:00,  3.80it/s, loss=0.504] 


[imdb][r=4] Epoch 5 | train_loss=0.2266 | val_acc=0.8428 | val_f1=0.8424 | time/epoch=344.92s
[imdb][r=4] Saved best model with val_acc=0.8428 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 6/10: 100%|██████████| 1250/1250 [05:16<00:00,  3.95it/s, loss=0.131] 


[imdb][r=4] Epoch 6 | train_loss=0.1976 | val_acc=0.8412 | val_f1=0.8411 | time/epoch=332.43s


[imdb][r=4] Epoch 7/10: 100%|██████████| 1250/1250 [05:24<00:00,  3.85it/s, loss=0.151] 


[imdb][r=4] Epoch 7 | train_loss=0.1764 | val_acc=0.8456 | val_f1=0.8468 | time/epoch=340.96s
[imdb][r=4] Saved best model with val_acc=0.8456 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 8/10: 100%|██████████| 1250/1250 [05:33<00:00,  3.75it/s, loss=0.166] 


[imdb][r=4] Epoch 8 | train_loss=0.1588 | val_acc=0.8460 | val_f1=0.8475 | time/epoch=349.01s
[imdb][r=4] Saved best model with val_acc=0.8460 to ./results_hira/best_model_r4.pt


[imdb][r=4] Epoch 9/10: 100%|██████████| 1250/1250 [05:28<00:00,  3.81it/s, loss=0.17]   


[imdb][r=4] Epoch 9 | train_loss=0.1462 | val_acc=0.8448 | val_f1=0.8459 | time/epoch=344.04s


[imdb][r=4] Epoch 10/10: 100%|██████████| 1250/1250 [05:32<00:00,  3.76it/s, loss=0.0692]


[imdb][r=4] Epoch 10 | train_loss=0.1393 | val_acc=0.8444 | val_f1=0.8455 | time/epoch=348.89s
[imdb][r=4] Training completed! best_val_acc=0.8460, avg_time/epoch=356.52s, converge_epoch=1, trainable_params=24777986, trainable_ratio=36.8244%
[imdb][r=4] Final test accuracy: 0.8472, test_f1=0.8526

------------------------------
HiRA with rank r=8
------------------------------
[imdb][r=8] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=8] Epoch 1/10: 100%|██████████| 1250/1250 [05:31<00:00,  3.77it/s, loss=0.361] 


[imdb][r=8] Epoch 1 | train_loss=0.4819 | val_acc=0.8172 | val_f1=0.8165 | time/epoch=346.88s
[imdb][r=8] Saved best model with val_acc=0.8172 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 2/10: 100%|██████████| 1250/1250 [05:22<00:00,  3.88it/s, loss=0.467] 


[imdb][r=8] Epoch 2 | train_loss=0.3487 | val_acc=0.8308 | val_f1=0.8314 | time/epoch=337.78s
[imdb][r=8] Saved best model with val_acc=0.8308 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 3/10: 100%|██████████| 1250/1250 [05:19<00:00,  3.92it/s, loss=0.344] 


[imdb][r=8] Epoch 3 | train_loss=0.2968 | val_acc=0.8360 | val_f1=0.8373 | time/epoch=334.93s
[imdb][r=8] Saved best model with val_acc=0.8360 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 4/10: 100%|██████████| 1250/1250 [05:19<00:00,  3.91it/s, loss=0.442] 


[imdb][r=8] Epoch 4 | train_loss=0.2576 | val_acc=0.8412 | val_f1=0.8434 | time/epoch=334.71s
[imdb][r=8] Saved best model with val_acc=0.8412 to ./results_hira/best_model_r8.pt


[imdb][r=8] Epoch 5/10: 100%|██████████| 1250/1250 [05:17<00:00,  3.93it/s, loss=0.346] 


[imdb][r=8] Epoch 5 | train_loss=0.2241 | val_acc=0.8420 | val_f1=0.8428 | time/epoch=333.13s


[imdb][r=8] Epoch 6/10: 100%|██████████| 1250/1250 [05:21<00:00,  3.89it/s, loss=0.322] 


[imdb][r=8] Epoch 6 | train_loss=0.1944 | val_acc=0.8412 | val_f1=0.8430 | time/epoch=336.76s


[imdb][r=8] Epoch 7/10: 100%|██████████| 1250/1250 [05:18<00:00,  3.93it/s, loss=0.0956]


[imdb][r=8] Epoch 7 | train_loss=0.1697 | val_acc=0.8400 | val_f1=0.8390 | time/epoch=333.72s


[imdb][r=8] Epoch 8/10: 100%|██████████| 1250/1250 [05:16<00:00,  3.95it/s, loss=0.612] 


[imdb][r=8] Epoch 8 | train_loss=0.1567 | val_acc=0.8404 | val_f1=0.8403 | time/epoch=331.87s


[imdb][r=8] Epoch 9/10: 100%|██████████| 1250/1250 [05:16<00:00,  3.95it/s, loss=0.251] 


[imdb][r=8] Epoch 9 | train_loss=0.1452 | val_acc=0.8416 | val_f1=0.8420 | time/epoch=330.89s


[imdb][r=8] Epoch 10/10: 100%|██████████| 1250/1250 [28:57<00:00,  1.39s/it, loss=0.0934]   


[imdb][r=8] Epoch 10 | train_loss=0.1365 | val_acc=0.8416 | val_f1=0.8422 | time/epoch=264.35s
[imdb][r=8] Training completed! best_val_acc=0.8412, avg_time/epoch=328.50s, converge_epoch=1, trainable_params=25109762, trainable_ratio=37.1344%
[imdb][r=8] Final test accuracy: 0.8448, test_f1=0.8507

------------------------------
HiRA with rank r=16
------------------------------
[imdb][r=16] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=16] Epoch 1/10: 100%|██████████| 1250/1250 [29:10<00:00,  1.40s/it, loss=0.552]   


[imdb][r=16] Epoch 1 | train_loss=0.4907 | val_acc=0.8228 | val_f1=0.8200 | time/epoch=256.82s
[imdb][r=16] Saved best model with val_acc=0.8228 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 2/10: 100%|██████████| 1250/1250 [45:04<00:00,  2.16s/it, loss=0.542]    


[imdb][r=16] Epoch 2 | train_loss=0.3496 | val_acc=0.8332 | val_f1=0.8343 | time/epoch=252.57s
[imdb][r=16] Saved best model with val_acc=0.8332 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 3/10: 100%|██████████| 1250/1250 [37:40<00:00,  1.81s/it, loss=0.198]   


[imdb][r=16] Epoch 3 | train_loss=0.2992 | val_acc=0.8396 | val_f1=0.8381 | time/epoch=255.30s
[imdb][r=16] Saved best model with val_acc=0.8396 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 4/10: 100%|██████████| 1250/1250 [57:42<00:00,  2.77s/it, loss=0.142]    


[imdb][r=16] Epoch 4 | train_loss=0.2591 | val_acc=0.8392 | val_f1=0.8370 | time/epoch=249.20s


[imdb][r=16] Epoch 5/10: 100%|██████████| 1250/1250 [56:46<00:00,  2.73s/it, loss=0.307]   


[imdb][r=16] Epoch 5 | train_loss=0.2251 | val_acc=0.8448 | val_f1=0.8469 | time/epoch=251.03s
[imdb][r=16] Saved best model with val_acc=0.8448 to ./results_hira/best_model_r16.pt


[imdb][r=16] Epoch 6/10: 100%|██████████| 1250/1250 [31:53<00:00,  1.53s/it, loss=0.346]     


[imdb][r=16] Epoch 6 | train_loss=0.1971 | val_acc=0.8444 | val_f1=0.8467 | time/epoch=249.67s


[imdb][r=16] Epoch 7/10: 100%|██████████| 1250/1250 [53:09<00:00,  2.55s/it, loss=0.0767]   


[imdb][r=16] Epoch 7 | train_loss=0.1712 | val_acc=0.8440 | val_f1=0.8466 | time/epoch=253.67s


[imdb][r=16] Epoch 8/10: 100%|██████████| 1250/1250 [28:20<00:00,  1.36s/it, loss=0.361]    


[imdb][r=16] Epoch 8 | train_loss=0.1549 | val_acc=0.8432 | val_f1=0.8447 | time/epoch=253.54s


[imdb][r=16] Epoch 9/10: 100%|██████████| 1250/1250 [1:00:30<00:00,  2.90s/it, loss=0.151]    


[imdb][r=16] Epoch 9 | train_loss=0.1439 | val_acc=0.8444 | val_f1=0.8457 | time/epoch=248.47s


[imdb][r=16] Epoch 10/10: 100%|██████████| 1250/1250 [52:38<00:00,  2.53s/it, loss=0.289]     


[imdb][r=16] Epoch 10 | train_loss=0.1355 | val_acc=0.8448 | val_f1=0.8453 | time/epoch=248.98s
[imdb][r=16] Training completed! best_val_acc=0.8448, avg_time/epoch=251.93s, converge_epoch=1, trainable_params=25773314, trainable_ratio=37.7453%
[imdb][r=16] Final test accuracy: 0.8424, test_f1=0.8483

Result Summary Table:

| Rank | Trainable Params / Total | Ratio | Val F1 | Val Acc | Test F1 | Test Acc | Sparsity (<1e−3) | Train Time (s) |
| ---- | ------------------------ | ----- | ------ | ------- | ------- | -------- | ----------------- | -------------- |
| 1 | 24,777,986 / 67.3M | 36.82% | 0.8475 | 0.8460 | 0.8526 | 0.8472 | 1.77% | 3565.19 |
| 2 | 24,612,098 / 67.1M | 36.67% | 0.8513 | 0.8492 | 0.8523 | 0.8460 | 1.75% | 3328.82 |
| 3 | 25,109,762 / 67.6M | 37.13% | 0.8434 | 0.8412 | 0.8507 | 0.8448 | 1.80% | 3285.02 |
| 4 | 25,773,314 / 68.3M | 37.75% | 0.8469 | 0.8448 | 0.8483 | 0.8424 | 1.88% | 2519.27 |


Summary over ranks:
r=2: val_acc=0.8492, test_acc=0.8460, avg_time/epoch

# Before changing

In [None]:
import os
import time
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from tqdm import tqdm
import evaluate


def train_hira_model(
    dataset_name: str = "sst2",           
    model_name: str = "distilbert-base-uncased",
    r: int = 32,                        
    lora_alpha: int = 32,
    num_epochs: int = 10,              
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    warmup_steps: int = 100,
    max_length: int = 128,
    output_dir: str = "./results_hira",
):

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"[{dataset_name}][r={r}] Using device: {device}")

    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        return_tensors="pt",
    )

    num_labels = 2
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    model = apply_hira_to_model(model, r=r, lora_alpha=lora_alpha)
    model = model.to(device)

    trainable_params, all_params, percentage = get_trainable_parameters(model)
    print(
        f"[{dataset_name}][r={r}] Trainable params: {trainable_params:,} || "
        f"All params: {all_params:,} || Trainable%: {percentage:.4f}%"
    )

    if dataset_name == "sst2":
        dataset = load_sst2_data(tokenizer, max_length)
    elif dataset_name == "imdb":
        dataset = load_imdb_data(tokenizer, max_length)
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    train_dataloader = DataLoader(
        dataset["train"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
    )
    val_dataloader = DataLoader(
        dataset["validation"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    test_dataloader = DataLoader(
        dataset["test"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
    )
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    os.makedirs(output_dir, exist_ok=True)

    epoch_times = []
    best_val_accuracy = 0.0
    best_epoch = -1

    for epoch in range(num_epochs):
        start_t = time.perf_counter()

        model.train()
        total_loss = 0.0
        progress_bar = tqdm(
            train_dataloader,
            desc=f"[{dataset_name}][r={r}] Epoch {epoch + 1}/{num_epochs}",
        )

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_loss / len(train_dataloader)

        model.eval()
        metric = evaluate.load("accuracy")
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=-1)
            metric.add_batch(predictions=preds, references=batch["labels"])
        val_acc = metric.compute()["accuracy"]

        end_t = time.perf_counter()
        epoch_time = end_t - start_t
        epoch_times.append(epoch_time)

        print(
            f"[{dataset_name}][r={r}] Epoch {epoch + 1} | "
            f"train_loss={avg_train_loss:.4f} | "
            f"val_acc={val_acc:.4f} | "
            f"time/epoch={epoch_time:.2f}s"
        )

        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            if best_epoch == -1:
                best_epoch = epoch + 1 

            save_path = os.path.join(output_dir, f"best_model_r{r}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "accuracy": best_val_accuracy,
                },
                save_path,
            )
            print(
                f"[{dataset_name}][r={r}] Saved best model with "
                f"val_acc={best_val_accuracy:.4f} to {save_path}"
            )

    avg_time_per_epoch = sum(epoch_times) / len(epoch_times)
    print(
        f"[{dataset_name}][r={r}] Training completed! "
        f"best_val_acc={best_val_accuracy:.4f}, "
        f"avg_time/epoch={avg_time_per_epoch:.2f}s, "
        f"converge_epoch={best_epoch}, "
        f"trainable_params={trainable_params}, "
        f"trainable_ratio={percentage:.4f}%"
    )

    model.eval()
    metric = evaluate.load("accuracy")
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=preds, references=batch["labels"])
    test_acc = metric.compute()["accuracy"]
    print(f"[{dataset_name}][r={r}] Final test accuracy: {test_acc:.4f}")

    return {
        "dataset": dataset_name,
        "r": r,
        "best_val_acc": best_val_accuracy,
        "test_acc": test_acc,
        "avg_time_per_epoch": avg_time_per_epoch,
        "converge_epoch": best_epoch,
        "trainable_params": trainable_params,
        "trainable_ratio": percentage,
    }


In [30]:
print("=" * 50)
print("Training HiRA on SST-2 with different ranks")
print("=" * 50)

results = []
for r in [2, 4, 8, 16]:
    print("\n" + "-" * 30)
    print(f"HiRA with rank r={r}")
    print("-" * 30)
    res = train_hira_model(
        dataset_name='sst2',
        r=r,
        num_epochs=10,
        batch_size=16,
        learning_rate=1e-3,
        max_length=128,
        output_dir=f'./results_hira_sst2_r{r}',
    )
    results.append(res)

print("\nSummary over ranks:")
for res in results:
    print(
        f"r={res['r']}: "
        f"val_acc={res['best_val_acc']:.4f}, "
        f"test_acc={res['test_acc']:.4f}, "
        f"avg_time/epoch={res['avg_time_per_epoch']:.2f}s, "
        f"converge_epoch={res['converge_epoch']}, "
        f"trainable={res['trainable_params']} ({res['trainable_ratio']:.4f}%)"
    )


Training HiRA on SST-2 with different ranks

------------------------------
HiRA with rank r=2
------------------------------
[sst2][r=2] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

Map: 100%|██████████| 53879/53879 [00:04<00:00, 12149.05 examples/s]
Map: 100%|██████████| 6734/6734 [00:00<00:00, 11892.88 examples/s]
Map: 100%|██████████| 6736/6736 [00:00<00:00, 12082.25 examples/s]
[sst2][r=2] Epoch 1/10: 100%|██████████| 3368/3368 [05:29<00:00, 10.23it/s, loss=0.106] 


[sst2][r=2] Epoch 1 | train_loss=0.2740 | val_acc=0.9180 | time/epoch=353.79s
[sst2][r=2] Saved best model with val_acc=0.9180 to ./results_hira_sst2_r2/best_model_r2.pt


[sst2][r=2] Epoch 2/10: 100%|██████████| 3368/3368 [05:41<00:00,  9.86it/s, loss=0.0809] 


[sst2][r=2] Epoch 2 | train_loss=0.1318 | val_acc=0.9252 | time/epoch=358.61s
[sst2][r=2] Saved best model with val_acc=0.9252 to ./results_hira_sst2_r2/best_model_r2.pt


[sst2][r=2] Epoch 3/10: 100%|██████████| 3368/3368 [05:39<00:00,  9.93it/s, loss=0.00314] 


[sst2][r=2] Epoch 3 | train_loss=0.0769 | val_acc=0.9275 | time/epoch=357.89s
[sst2][r=2] Saved best model with val_acc=0.9275 to ./results_hira_sst2_r2/best_model_r2.pt


[sst2][r=2] Epoch 4/10: 100%|██████████| 3368/3368 [05:29<00:00, 10.23it/s, loss=0.0408]  


[sst2][r=2] Epoch 4 | train_loss=0.0452 | val_acc=0.9275 | time/epoch=346.72s


[sst2][r=2] Epoch 5/10: 100%|██████████| 3368/3368 [05:28<00:00, 10.25it/s, loss=2.57e-5] 


[sst2][r=2] Epoch 5 | train_loss=0.0282 | val_acc=0.9298 | time/epoch=345.83s
[sst2][r=2] Saved best model with val_acc=0.9298 to ./results_hira_sst2_r2/best_model_r2.pt


[sst2][r=2] Epoch 6/10: 100%|██████████| 3368/3368 [05:30<00:00, 10.18it/s, loss=0.000182]


[sst2][r=2] Epoch 6 | train_loss=0.0180 | val_acc=0.9302 | time/epoch=348.16s
[sst2][r=2] Saved best model with val_acc=0.9302 to ./results_hira_sst2_r2/best_model_r2.pt


[sst2][r=2] Epoch 7/10: 100%|██████████| 3368/3368 [05:30<00:00, 10.19it/s, loss=0.000142]


[sst2][r=2] Epoch 7 | train_loss=0.0112 | val_acc=0.9289 | time/epoch=347.81s


[sst2][r=2] Epoch 8/10: 100%|██████████| 3368/3368 [05:39<00:00,  9.92it/s, loss=0.000101]


[sst2][r=2] Epoch 8 | train_loss=0.0061 | val_acc=0.9308 | time/epoch=357.74s
[sst2][r=2] Saved best model with val_acc=0.9308 to ./results_hira_sst2_r2/best_model_r2.pt


[sst2][r=2] Epoch 9/10: 100%|██████████| 3368/3368 [05:44<00:00,  9.79it/s, loss=3.44e-6] 


[sst2][r=2] Epoch 9 | train_loss=0.0047 | val_acc=0.9308 | time/epoch=362.16s


[sst2][r=2] Epoch 10/10: 100%|██████████| 3368/3368 [05:44<00:00,  9.79it/s, loss=1.7e-8]  


[sst2][r=2] Epoch 10 | train_loss=0.0029 | val_acc=0.9317 | time/epoch=362.18s
[sst2][r=2] Saved best model with val_acc=0.9317 to ./results_hira_sst2_r2/best_model_r2.pt
[sst2][r=2] Training completed! best_val_acc=0.9317, avg_time/epoch=354.09s, converge_epoch=1, trainable_params=24612098, trainable_ratio=36.6683%
[sst2][r=2] Final test accuracy: 0.9314

------------------------------
HiRA with rank r=4
------------------------------
[sst2][r=4] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=4] Epoch 1/10: 100%|██████████| 3368/3368 [05:42<00:00,  9.83it/s, loss=0.278] 


[sst2][r=4] Epoch 1 | train_loss=0.2755 | val_acc=0.9208 | time/epoch=360.97s
[sst2][r=4] Saved best model with val_acc=0.9208 to ./results_hira_sst2_r4/best_model_r4.pt


[sst2][r=4] Epoch 2/10: 100%|██████████| 3368/3368 [05:48<00:00,  9.67it/s, loss=0.0178] 


[sst2][r=4] Epoch 2 | train_loss=0.1334 | val_acc=0.9287 | time/epoch=366.02s
[sst2][r=4] Saved best model with val_acc=0.9287 to ./results_hira_sst2_r4/best_model_r4.pt


[sst2][r=4] Epoch 3/10: 100%|██████████| 3368/3368 [05:47<00:00,  9.70it/s, loss=0.667]   


[sst2][r=4] Epoch 3 | train_loss=0.0809 | val_acc=0.9253 | time/epoch=364.86s


[sst2][r=4] Epoch 4/10: 100%|██████████| 3368/3368 [05:59<00:00,  9.37it/s, loss=0.0154]  


[sst2][r=4] Epoch 4 | train_loss=0.0478 | val_acc=0.9278 | time/epoch=376.82s


[sst2][r=4] Epoch 5/10: 100%|██████████| 3368/3368 [05:50<00:00,  9.60it/s, loss=0.00536] 


[sst2][r=4] Epoch 5 | train_loss=0.0288 | val_acc=0.9299 | time/epoch=369.61s
[sst2][r=4] Saved best model with val_acc=0.9299 to ./results_hira_sst2_r4/best_model_r4.pt


[sst2][r=4] Epoch 6/10: 100%|██████████| 3368/3368 [06:11<00:00,  9.07it/s, loss=0.00525] 


[sst2][r=4] Epoch 6 | train_loss=0.0183 | val_acc=0.9308 | time/epoch=389.24s
[sst2][r=4] Saved best model with val_acc=0.9308 to ./results_hira_sst2_r4/best_model_r4.pt


[sst2][r=4] Epoch 7/10: 100%|██████████| 3368/3368 [05:56<00:00,  9.45it/s, loss=1.59e-5] 


[sst2][r=4] Epoch 7 | train_loss=0.0105 | val_acc=0.9320 | time/epoch=373.95s
[sst2][r=4] Saved best model with val_acc=0.9320 to ./results_hira_sst2_r4/best_model_r4.pt


[sst2][r=4] Epoch 8/10: 100%|██████████| 3368/3368 [05:47<00:00,  9.68it/s, loss=0.644]   


[sst2][r=4] Epoch 8 | train_loss=0.0070 | val_acc=0.9326 | time/epoch=365.87s
[sst2][r=4] Saved best model with val_acc=0.9326 to ./results_hira_sst2_r4/best_model_r4.pt


[sst2][r=4] Epoch 9/10: 100%|██████████| 3368/3368 [05:39<00:00,  9.91it/s, loss=2.09e-6] 


[sst2][r=4] Epoch 9 | train_loss=0.0051 | val_acc=0.9307 | time/epoch=358.45s


[sst2][r=4] Epoch 10/10: 100%|██████████| 3368/3368 [05:51<00:00,  9.59it/s, loss=0]       


[sst2][r=4] Epoch 10 | train_loss=0.0032 | val_acc=0.9318 | time/epoch=371.67s
[sst2][r=4] Training completed! best_val_acc=0.9326, avg_time/epoch=369.75s, converge_epoch=1, trainable_params=24777986, trainable_ratio=36.8244%
[sst2][r=4] Final test accuracy: 0.9317

------------------------------
HiRA with rank r=8
------------------------------
[sst2][r=8] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=8] Epoch 1/10: 100%|██████████| 3368/3368 [05:53<00:00,  9.54it/s, loss=0.015] 


[sst2][r=8] Epoch 1 | train_loss=0.2795 | val_acc=0.9182 | time/epoch=371.01s
[sst2][r=8] Saved best model with val_acc=0.9182 to ./results_hira_sst2_r8/best_model_r8.pt


[sst2][r=8] Epoch 2/10: 100%|██████████| 3368/3368 [05:52<00:00,  9.56it/s, loss=0.0939] 


[sst2][r=8] Epoch 2 | train_loss=0.1370 | val_acc=0.9311 | time/epoch=370.40s
[sst2][r=8] Saved best model with val_acc=0.9311 to ./results_hira_sst2_r8/best_model_r8.pt


[sst2][r=8] Epoch 3/10: 100%|██████████| 3368/3368 [05:50<00:00,  9.61it/s, loss=0.000691]


[sst2][r=8] Epoch 3 | train_loss=0.0794 | val_acc=0.9263 | time/epoch=368.53s


[sst2][r=8] Epoch 4/10: 100%|██████████| 3368/3368 [05:51<00:00,  9.59it/s, loss=0.00608] 


[sst2][r=8] Epoch 4 | train_loss=0.0473 | val_acc=0.9283 | time/epoch=369.30s


[sst2][r=8] Epoch 5/10: 100%|██████████| 3368/3368 [05:47<00:00,  9.69it/s, loss=0.0308]  


[sst2][r=8] Epoch 5 | train_loss=0.0297 | val_acc=0.9269 | time/epoch=365.72s


[sst2][r=8] Epoch 6/10: 100%|██████████| 3368/3368 [05:50<00:00,  9.60it/s, loss=0.154]   


[sst2][r=8] Epoch 6 | train_loss=0.0180 | val_acc=0.9281 | time/epoch=369.10s


[sst2][r=8] Epoch 7/10: 100%|██████████| 3368/3368 [05:50<00:00,  9.62it/s, loss=1.22e-5] 


[sst2][r=8] Epoch 7 | train_loss=0.0120 | val_acc=0.9274 | time/epoch=368.07s


[sst2][r=8] Epoch 8/10: 100%|██████████| 3368/3368 [06:00<00:00,  9.35it/s, loss=0.00605] 


[sst2][r=8] Epoch 8 | train_loss=0.0065 | val_acc=0.9296 | time/epoch=378.31s


[sst2][r=8] Epoch 9/10: 100%|██████████| 3368/3368 [05:55<00:00,  9.47it/s, loss=0]       


[sst2][r=8] Epoch 9 | train_loss=0.0045 | val_acc=0.9296 | time/epoch=373.80s


[sst2][r=8] Epoch 10/10: 100%|██████████| 3368/3368 [05:54<00:00,  9.50it/s, loss=0]       


[sst2][r=8] Epoch 10 | train_loss=0.0034 | val_acc=0.9289 | time/epoch=372.78s
[sst2][r=8] Training completed! best_val_acc=0.9311, avg_time/epoch=370.70s, converge_epoch=1, trainable_params=25109762, trainable_ratio=37.1344%
[sst2][r=8] Final test accuracy: 0.9348

------------------------------
HiRA with rank r=16
------------------------------
[sst2][r=16] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[sst2][r=16] Epoch 1/10: 100%|██████████| 3368/3368 [05:53<00:00,  9.54it/s, loss=0.116] 


[sst2][r=16] Epoch 1 | train_loss=0.2743 | val_acc=0.9214 | time/epoch=371.52s
[sst2][r=16] Saved best model with val_acc=0.9214 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 2/10: 100%|██████████| 3368/3368 [05:45<00:00,  9.75it/s, loss=0.0526]  


[sst2][r=16] Epoch 2 | train_loss=0.1329 | val_acc=0.9246 | time/epoch=363.69s
[sst2][r=16] Saved best model with val_acc=0.9246 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 3/10: 100%|██████████| 3368/3368 [05:51<00:00,  9.57it/s, loss=0.0117]  


[sst2][r=16] Epoch 3 | train_loss=0.0775 | val_acc=0.9298 | time/epoch=369.64s
[sst2][r=16] Saved best model with val_acc=0.9298 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 4/10: 100%|██████████| 3368/3368 [05:58<00:00,  9.40it/s, loss=0.00043] 


[sst2][r=16] Epoch 4 | train_loss=0.0456 | val_acc=0.9299 | time/epoch=377.20s
[sst2][r=16] Saved best model with val_acc=0.9299 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 5/10: 100%|██████████| 3368/3368 [05:56<00:00,  9.44it/s, loss=4.45e-5] 


[sst2][r=16] Epoch 5 | train_loss=0.0278 | val_acc=0.9305 | time/epoch=375.78s
[sst2][r=16] Saved best model with val_acc=0.9305 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 6/10: 100%|██████████| 3368/3368 [06:08<00:00,  9.14it/s, loss=0.0781]  


[sst2][r=16] Epoch 6 | train_loss=0.0170 | val_acc=0.9307 | time/epoch=388.21s
[sst2][r=16] Saved best model with val_acc=0.9307 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 7/10: 100%|██████████| 3368/3368 [06:20<00:00,  8.86it/s, loss=3.58e-5] 


[sst2][r=16] Epoch 7 | train_loss=0.0114 | val_acc=0.9295 | time/epoch=399.63s


[sst2][r=16] Epoch 8/10: 100%|██████████| 3368/3368 [06:18<00:00,  8.90it/s, loss=0.0324]  


[sst2][r=16] Epoch 8 | train_loss=0.0063 | val_acc=0.9315 | time/epoch=398.23s
[sst2][r=16] Saved best model with val_acc=0.9315 to ./results_hira_sst2_r16/best_model_r16.pt


[sst2][r=16] Epoch 9/10: 100%|██████████| 3368/3368 [06:18<00:00,  8.89it/s, loss=0.00283] 


[sst2][r=16] Epoch 9 | train_loss=0.0047 | val_acc=0.9307 | time/epoch=398.27s


[sst2][r=16] Epoch 10/10: 100%|██████████| 3368/3368 [06:17<00:00,  8.93it/s, loss=0]       


[sst2][r=16] Epoch 10 | train_loss=0.0028 | val_acc=0.9311 | time/epoch=396.70s
[sst2][r=16] Training completed! best_val_acc=0.9315, avg_time/epoch=383.89s, converge_epoch=1, trainable_params=25773314, trainable_ratio=37.7453%
[sst2][r=16] Final test accuracy: 0.9313

Summary over ranks:
r=2: val_acc=0.9317, test_acc=0.9314, avg_time/epoch=354.09s, converge_epoch=1, trainable=24612098 (36.6683%)
r=4: val_acc=0.9326, test_acc=0.9317, avg_time/epoch=369.75s, converge_epoch=1, trainable=24777986 (36.8244%)
r=8: val_acc=0.9311, test_acc=0.9348, avg_time/epoch=370.70s, converge_epoch=1, trainable=25109762 (37.1344%)
r=16: val_acc=0.9315, test_acc=0.9313, avg_time/epoch=383.89s, converge_epoch=1, trainable=25773314 (37.7453%)


In [33]:
for r in [2, 4, 8, 16]:
    res = train_hira_model(
        dataset_name="imdb",
        r=r,
        num_epochs=10,
        batch_size=8,
        learning_rate=1e-3,
        max_length=256,
        output_dir=f"./results_hira_imdb_r{r}",
    )


[imdb][r=2] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

Map: 100%|██████████| 20000/20000 [00:27<00:00, 740.22 examples/s]
Map: 100%|██████████| 2500/2500 [00:03<00:00, 699.49 examples/s]
Map: 100%|██████████| 2500/2500 [00:03<00:00, 682.79 examples/s]
[imdb][r=2] Epoch 1/10: 100%|██████████| 2500/2500 [11:02<00:00,  3.77it/s, loss=0.441]  


[imdb][r=2] Epoch 1 | train_loss=0.3759 | val_acc=0.8572 | time/epoch=690.73s
[imdb][r=2] Saved best model with val_acc=0.8572 to ./results_hira_imdb_r2/best_model_r2.pt


[imdb][r=2] Epoch 2/10: 100%|██████████| 2500/2500 [10:58<00:00,  3.79it/s, loss=0.211]  


[imdb][r=2] Epoch 2 | train_loss=0.1880 | val_acc=0.8400 | time/epoch=690.20s


[imdb][r=2] Epoch 3/10: 100%|██████████| 2500/2500 [10:46<00:00,  3.87it/s, loss=0.0151]  


[imdb][r=2] Epoch 3 | train_loss=0.0668 | val_acc=0.8476 | time/epoch=676.94s


[imdb][r=2] Epoch 4/10: 100%|██████████| 2500/2500 [10:57<00:00,  3.80it/s, loss=0.0379]  


[imdb][r=2] Epoch 4 | train_loss=0.0277 | val_acc=0.8464 | time/epoch=687.56s


[imdb][r=2] Epoch 5/10: 100%|██████████| 2500/2500 [11:14<00:00,  3.71it/s, loss=0.0274]  


[imdb][r=2] Epoch 5 | train_loss=0.0136 | val_acc=0.8468 | time/epoch=706.28s


[imdb][r=2] Epoch 6/10: 100%|██████████| 2500/2500 [11:22<00:00,  3.67it/s, loss=1.91e-6] 


[imdb][r=2] Epoch 6 | train_loss=0.0057 | val_acc=0.8472 | time/epoch=713.00s


[imdb][r=2] Epoch 7/10: 100%|██████████| 2500/2500 [10:59<00:00,  3.79it/s, loss=3.13e-7] 


[imdb][r=2] Epoch 7 | train_loss=0.0030 | val_acc=0.8476 | time/epoch=690.18s


[imdb][r=2] Epoch 8/10: 100%|██████████| 2500/2500 [10:55<00:00,  3.82it/s, loss=0]       


[imdb][r=2] Epoch 8 | train_loss=0.0008 | val_acc=0.8496 | time/epoch=685.85s


[imdb][r=2] Epoch 9/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.84it/s, loss=0]       


[imdb][r=2] Epoch 9 | train_loss=0.0002 | val_acc=0.8484 | time/epoch=681.26s


[imdb][r=2] Epoch 10/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.84it/s, loss=2.98e-8] 


[imdb][r=2] Epoch 10 | train_loss=0.0000 | val_acc=0.8496 | time/epoch=680.70s
[imdb][r=2] Training completed! best_val_acc=0.8572, avg_time/epoch=690.27s, converge_epoch=1, trainable_params=24612098, trainable_ratio=36.6683%
[imdb][r=2] Final test accuracy: 0.8572
[imdb][r=4] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=4] Epoch 1/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.85it/s, loss=0.261]  


[imdb][r=4] Epoch 1 | train_loss=0.3785 | val_acc=0.8560 | time/epoch=680.38s
[imdb][r=4] Saved best model with val_acc=0.8560 to ./results_hira_imdb_r4/best_model_r4.pt


[imdb][r=4] Epoch 2/10: 100%|██████████| 2500/2500 [10:49<00:00,  3.85it/s, loss=0.28]   


[imdb][r=4] Epoch 2 | train_loss=0.1879 | val_acc=0.8544 | time/epoch=680.36s


[imdb][r=4] Epoch 3/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.84it/s, loss=0.000566]


[imdb][r=4] Epoch 3 | train_loss=0.0686 | val_acc=0.8580 | time/epoch=681.21s
[imdb][r=4] Saved best model with val_acc=0.8580 to ./results_hira_imdb_r4/best_model_r4.pt


[imdb][r=4] Epoch 4/10: 100%|██████████| 2500/2500 [10:49<00:00,  3.85it/s, loss=0.000998]


[imdb][r=4] Epoch 4 | train_loss=0.0296 | val_acc=0.8532 | time/epoch=680.05s


[imdb][r=4] Epoch 5/10: 100%|██████████| 2500/2500 [10:51<00:00,  3.84it/s, loss=0.000314]


[imdb][r=4] Epoch 5 | train_loss=0.0137 | val_acc=0.8504 | time/epoch=681.54s


[imdb][r=4] Epoch 6/10: 100%|██████████| 2500/2500 [10:48<00:00,  3.85it/s, loss=0.000556]


[imdb][r=4] Epoch 6 | train_loss=0.0068 | val_acc=0.8468 | time/epoch=678.96s


[imdb][r=4] Epoch 7/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.84it/s, loss=0.0064]  


[imdb][r=4] Epoch 7 | train_loss=0.0036 | val_acc=0.8480 | time/epoch=681.47s


[imdb][r=4] Epoch 8/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.84it/s, loss=4.47e-8] 


[imdb][r=4] Epoch 8 | train_loss=0.0008 | val_acc=0.8540 | time/epoch=681.05s


[imdb][r=4] Epoch 9/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.85it/s, loss=0]       


[imdb][r=4] Epoch 9 | train_loss=0.0000 | val_acc=0.8548 | time/epoch=680.99s


[imdb][r=4] Epoch 10/10: 100%|██████████| 2500/2500 [10:51<00:00,  3.83it/s, loss=0]       


[imdb][r=4] Epoch 10 | train_loss=0.0000 | val_acc=0.8544 | time/epoch=682.23s
[imdb][r=4] Training completed! best_val_acc=0.8580, avg_time/epoch=680.82s, converge_epoch=1, trainable_params=24777986, trainable_ratio=36.8244%
[imdb][r=4] Final test accuracy: 0.8604
[imdb][r=8] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=8] Epoch 1/10: 100%|██████████| 2500/2500 [10:50<00:00,  3.84it/s, loss=0.813] 


[imdb][r=8] Epoch 1 | train_loss=0.3872 | val_acc=0.8640 | time/epoch=681.09s
[imdb][r=8] Saved best model with val_acc=0.8640 to ./results_hira_imdb_r8/best_model_r8.pt


[imdb][r=8] Epoch 2/10: 100%|██████████| 2500/2500 [10:51<00:00,  3.84it/s, loss=0.456]  


[imdb][r=8] Epoch 2 | train_loss=0.1892 | val_acc=0.8560 | time/epoch=681.37s


[imdb][r=8] Epoch 3/10: 100%|██████████| 2500/2500 [10:51<00:00,  3.84it/s, loss=0.329]   


[imdb][r=8] Epoch 3 | train_loss=0.0687 | val_acc=0.8484 | time/epoch=685.98s


[imdb][r=8] Epoch 4/10: 100%|██████████| 2500/2500 [10:48<00:00,  3.85it/s, loss=0.000294]


[imdb][r=8] Epoch 4 | train_loss=0.0280 | val_acc=0.8424 | time/epoch=678.54s


[imdb][r=8] Epoch 5/10: 100%|██████████| 2500/2500 [10:49<00:00,  3.85it/s, loss=0.000146]


[imdb][r=8] Epoch 5 | train_loss=0.0154 | val_acc=0.8424 | time/epoch=679.28s


[imdb][r=8] Epoch 6/10: 100%|██████████| 2500/2500 [10:48<00:00,  3.86it/s, loss=1.8e-6]  


[imdb][r=8] Epoch 6 | train_loss=0.0071 | val_acc=0.8428 | time/epoch=677.99s


[imdb][r=8] Epoch 7/10: 100%|██████████| 2500/2500 [10:51<00:00,  3.84it/s, loss=1.68e-5] 


[imdb][r=8] Epoch 7 | train_loss=0.0027 | val_acc=0.8480 | time/epoch=681.88s


[imdb][r=8] Epoch 8/10: 100%|██████████| 2500/2500 [10:47<00:00,  3.86it/s, loss=2.98e-8] 


[imdb][r=8] Epoch 8 | train_loss=0.0010 | val_acc=0.8384 | time/epoch=677.75s


[imdb][r=8] Epoch 9/10: 100%|██████████| 2500/2500 [10:55<00:00,  3.82it/s, loss=0]       


[imdb][r=8] Epoch 9 | train_loss=0.0002 | val_acc=0.8440 | time/epoch=686.61s


[imdb][r=8] Epoch 10/10: 100%|██████████| 2500/2500 [10:57<00:00,  3.80it/s, loss=0]       


[imdb][r=8] Epoch 10 | train_loss=0.0000 | val_acc=0.8444 | time/epoch=687.36s
[imdb][r=8] Training completed! best_val_acc=0.8640, avg_time/epoch=681.78s, converge_epoch=1, trainable_params=25109762, trainable_ratio=37.1344%
[imdb][r=8] Final test accuracy: 0.8552
[imdb][r=16] Using device: mps


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Applied HiRA to: distilbert.transformer.layer.0.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.0.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.0.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.1.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.1.attention.out_lin
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin1
Applied HiRA to: distilbert.transformer.layer.1.ffn.lin2
Applied HiRA to: distilbert.transformer.layer.2.attention.q_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.k_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.v_lin
Applied HiRA to: distilbert.transformer.layer.2.attention.out_li

[imdb][r=16] Epoch 1/10: 100%|██████████| 2500/2500 [10:51<00:00,  3.84it/s, loss=0.101] 


[imdb][r=16] Epoch 1 | train_loss=0.3840 | val_acc=0.8572 | time/epoch=681.52s
[imdb][r=16] Saved best model with val_acc=0.8572 to ./results_hira_imdb_r16/best_model_r16.pt


[imdb][r=16] Epoch 2/10: 100%|██████████| 2500/2500 [10:54<00:00,  3.82it/s, loss=0.117]  


[imdb][r=16] Epoch 2 | train_loss=0.1917 | val_acc=0.8540 | time/epoch=685.14s


[imdb][r=16] Epoch 3/10: 100%|██████████| 2500/2500 [11:13<00:00,  3.71it/s, loss=0.264]   


[imdb][r=16] Epoch 3 | train_loss=0.0736 | val_acc=0.8392 | time/epoch=704.23s


[imdb][r=16] Epoch 4/10: 100%|██████████| 2500/2500 [11:47<00:00,  3.53it/s, loss=0.428]   


[imdb][r=16] Epoch 4 | train_loss=0.0334 | val_acc=0.8264 | time/epoch=742.05s


[imdb][r=16] Epoch 5/10: 100%|██████████| 2500/2500 [12:08<00:00,  3.43it/s, loss=0.00231] 


[imdb][r=16] Epoch 5 | train_loss=0.0131 | val_acc=0.8376 | time/epoch=766.39s


[imdb][r=16] Epoch 6/10: 100%|██████████| 2500/2500 [12:20<00:00,  3.38it/s, loss=1.28e-6] 


[imdb][r=16] Epoch 6 | train_loss=0.0066 | val_acc=0.8432 | time/epoch=774.89s


[imdb][r=16] Epoch 7/10: 100%|██████████| 2500/2500 [12:02<00:00,  3.46it/s, loss=0.000275]


[imdb][r=16] Epoch 7 | train_loss=0.0040 | val_acc=0.8424 | time/epoch=755.74s


[imdb][r=16] Epoch 8/10: 100%|██████████| 2500/2500 [11:53<00:00,  3.50it/s, loss=1.01e-6] 


[imdb][r=16] Epoch 8 | train_loss=0.0012 | val_acc=0.8464 | time/epoch=745.77s


[imdb][r=16] Epoch 9/10: 100%|██████████| 2500/2500 [12:25<00:00,  3.35it/s, loss=0]       


[imdb][r=16] Epoch 9 | train_loss=0.0000 | val_acc=0.8436 | time/epoch=778.97s


[imdb][r=16] Epoch 10/10: 100%|██████████| 2500/2500 [11:52<00:00,  3.51it/s, loss=0]       


[imdb][r=16] Epoch 10 | train_loss=0.0000 | val_acc=0.8436 | time/epoch=744.97s
[imdb][r=16] Training completed! best_val_acc=0.8572, avg_time/epoch=737.97s, converge_epoch=1, trainable_params=25773314, trainable_ratio=37.7453%
[imdb][r=16] Final test accuracy: 0.8504
