In [1]:
! pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [2]:
import json

def save_svd_config(config, file_path="svd_config.json"):
    with open(file_path, "w") as f:
        json.dump(config, f, indent=4)

def load_svd_config(file_path="svd_config.json"):
    with open(file_path, "r") as f:
        return json.load(f)

In [3]:
import numpy as np

def compute_effective_rank(matrix):
    """
    Compute the effective rank of a matrix based on the definition provided.
    """
    _, S, _ = torch.linalg.svd(matrix, full_matrices=False)
    singular_values = S.cpu().numpy()

    # Compute the singular value distribution (p_k)
    l1_norm = np.sum(np.abs(singular_values))
    p_k = singular_values / l1_norm

    # Compute the Shannon entropy
    H = -np.sum(p_k * np.log(p_k + 1e-10))  # Add a small constant to avoid log(0)

    # Compute the effective rank
    effective_rank = np.exp(H)

    return effective_rank

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm
import time

# Make sure to run on GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


###################################################
# 1. Define a PyTorch Dataset for AG News
###################################################
class AGNewsDataset(Dataset):
    """
    PyTorch dataset wrapper for the AG News dataset.
    Each example is converted to a text-to-text format.
    """
    def __init__(self, hf_dataset, split, tokenizer, label_mapping):
        """
        hf_dataset: the Hugging Face dataset loaded via load_dataset("ag_news")
        split: "train" or "test"
        tokenizer: a T5Tokenizer instance
        label_mapping: a dict mapping integer labels to string labels, e.g. {0:"World", ...}
        """
        self.dataset = hf_dataset[split]
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get the sample
        sample = self.dataset[idx]
        # The AG News dataset has a "text" field and a "label" field
        text = sample["text"]
        label_id = sample["label"]
        # Create input prompt and target text. You can modify the prompt as desired.
        input_text = "classify: " + text
        target_text = self.label_mapping[label_id]
        return input_text, target_text


###################################################
# 2. Collate Function
###################################################
def collate_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    """
    Tokenize the batch of input and target texts.
    Returns a dictionary with tokenized input_ids, attention_mask, and labels.
    """
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer(list(targets), padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")
    # T5 uses the "labels" field as the target. (Note: the tokenizer converts target text to input_ids.)
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings


###################################################
# 3. Training and Evaluation Functions
###################################################
def train_finetune_t5():
    # Load the AG News dataset from Hugging Face
    hf_dataset = load_dataset("ag_news")

    # Define the label mapping
    label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Load pretrained T5 tokenizer and model (T5-small)
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)

    # Create PyTorch datasets for train and test splits
    train_dataset = AGNewsDataset(hf_dataset, "train", tokenizer, label_mapping)
    test_dataset = AGNewsDataset(hf_dataset, "test", tokenizer, label_mapping)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=lambda batch: collate_fn(batch, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn(batch, tokenizer))

    # Prepare optimizer (full fine-tuning; all model parameters are updated)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    num_epochs = 3

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            # Move batch to device
            for key, val in batch.items():
                batch[key] = val.to(device)

            outputs = model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Estimate time remaining
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=loss.item(), eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    # Save the fine-tuned model (model1)
    torch.save(model.state_dict(), "t5_finetuned_agnews.pt")
    print("Model saved as 't5_finetuned_agnews.pt'.")

    return model, tokenizer, test_loader


def evaluate(model, tokenizer, test_loader):
    """
    Evaluate the fine-tuned model on the test set.
    Uses model.generate() to produce output text and compares to the expected label.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            for key, val in batch.items():
                batch[key] = val.to(device)
            # Generate predictions; set a max_length suitable for the short target texts
            generated_ids = model.generate(batch["input_ids"],
                                           attention_mask=batch["attention_mask"],
                                           max_length=16)
            predictions = [tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_ids]
            # Decode the ground truth labels from the batch's labels field
            # (Here we assume that labels are token IDs representing a short text such as "Sports")
            targets = [tokenizer.decode(t, skip_special_tokens=True).strip() for t in batch["labels"]]

            for pred, target in zip(predictions, targets):
                if pred.lower() == target.lower():
                    correct += 1
                total += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Test Accuracy: {accuracy*100:.2f}%")
    return accuracy


###################################################
# 5. Main: Train, Check, and Evaluate
###################################################
if __name__ == "__main__":
    # Train and fine-tune T5 on AG News (model1)
    model1, tokenizer, test_loader = train_finetune_t5()

    # Evaluate the reloaded model (model2) on the test set.
    evaluate(model1, tokenizer, test_loader)

Epoch 1/3: 100%|██████████| 15000/15000 [23:19<00:00, 10.72batch/s, eta=0.09s, loss=0.00591]


Epoch 1/3 - Average Loss: 0.1053


Epoch 2/3: 100%|██████████| 15000/15000 [22:59<00:00, 10.87batch/s, eta=0.18s, loss=0.203]


Epoch 2/3 - Average Loss: 0.0527


Epoch 3/3: 100%|██████████| 15000/15000 [22:56<00:00, 10.90batch/s, eta=0.18s, loss=0.00669]


Epoch 3/3 - Average Loss: 0.0431
Model saved as 't5_finetuned_agnews.pt'.
Test Accuracy: 93.99%


In [6]:
import os
import json
import csv
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm

torch.autograd.set_detect_anomaly(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###################################################
# 1. Helper Functions for SVD and Parameter Management
###################################################

def decompose_weight_matrix(weight: torch.Tensor, top_k: int):
    """
    Perform SVD on a 2D weight matrix and split into:
      - top_k singular vectors (treated as frozen/buffers)
      - the rest (treated as trainable)
    Returns a dictionary containing:
      {
        "U_high": ...  # buffer
        "S_high": ...  # buffer
        "V_high": ...  # buffer
        "U_low": ...   # parameter
        "S_low": ...   # parameter
        "V_low": ...   # parameter
        "rank_high": top_k
      }
    """
    device_local = weight.device
    W = weight.to(torch.float32)  # ensure float32 for SVD
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)
    # Ensure we don’t ask for more than available
    k = min(top_k, S.shape[0])

    # High subspace (frozen)
    U_high = U[:, :k].detach().to(device_local)
    S_high = S[:k].detach().to(device_local)
    V_high = Vt[:k, :].detach().to(device_local)

    # Low subspace (trainable)
    U_low = U[:, k:].detach().to(device_local)
    S_low = S[k:].detach().to(device_local)
    V_low = Vt[k:, :].detach().to(device_local)

    return {
        "U_high": U_high,
        "S_high": S_high,
        "V_high": V_high,
        "U_low": nn.Parameter(U_low),
        "S_low": nn.Parameter(S_low),
        "V_low": nn.Parameter(V_low),
        "rank_high": k
    }


def reconstruct_weight_matrix(svd_dict):
    """
    Reconstruct the full weight matrix:
       W = U_high * diag(S_high) * V_high^T + U_low * diag(S_low) * V_low^T
    """
    U_high = svd_dict["U_high"]
    S_high = svd_dict["S_high"]
    V_high = svd_dict["V_high"]
    U_low = svd_dict["U_low"]
    S_low = svd_dict["S_low"]
    V_low = svd_dict["V_low"]

    if U_high.shape[1] > 0 and S_high.shape[0] > 0:
        high_part = torch.mm(U_high * S_high.unsqueeze(0), V_high)
    else:
        high_part = torch.zeros(U_low.size(0), V_low.size(1), device=U_high.device)

    if U_low.shape[1] > 0 and S_low.shape[0] > 0:
        US_low = U_low * S_low.unsqueeze(0)
        low_part = torch.mm(US_low, V_low)
    else:
        low_part = torch.zeros(U_high.size(0), V_high.size(1), device=U_low.device)

    return high_part + low_part


def check_reconstruction_error(weight, svd_dict, atol=1e-5):
    # Move the weight to the same device as the U_high buffer
    target_device = svd_dict["U_high"].device
    weight = weight.to(target_device)
    W_recon = reconstruct_weight_matrix(svd_dict)
    # Ensure reconstruction is also on the target device
    W_recon = W_recon.to(target_device)
    error = torch.norm(weight - W_recon) / torch.norm(weight)
    if error > atol:
        print(f"Warning: Reconstruction error {error:.2e} exceeds tolerance {atol}")
    return error


def project_gradient_to_orthogonal_space(svd_dict):
    """
    Remove from the gradients of the low subspace any component that lies
    in the high subspace.
    """
    if (svd_dict["U_low"].grad is None and
        svd_dict["S_low"].grad is None and
        svd_dict["V_low"].grad is None):
        return

    U_high = svd_dict["U_high"]
    V_high = svd_dict["V_high"]

    if svd_dict["U_low"].grad is not None:
        dU = svd_dict["U_low"].grad
        proj = U_high @ (U_high.transpose(0,1) @ dU)
        dU.sub_(proj)

    if svd_dict["V_low"].grad is not None:
        dV = svd_dict["V_low"].grad
        proj = (dV @ V_high.transpose(0,1)) @ V_high
        dV.sub_(proj)
    # We leave S_low unchanged


###################################################
# 2. T5 Model Subclass with SVD (Only for Selected Parameters)
###################################################

class T5WithSVD(T5ForConditionalGeneration):
    """
    Subclass that, on initialization, decomposes selected weight matrices via SVD.
    Only parameters specified in the svd_config are decomposed.
    For each such 2D weight, we freeze the top singular vectors (50% by default)
    and register the lower half (trainable) as parameters.

    Additionally, we pre-compute the module mapping for faster weight injection.
    """
    def __init__(self, config: T5Config, svd_config=None, initialize_svd=True):
        super().__init__(config)
        # svd_config is a dict mapping full parameter names to top_k values.
        self.svd_config = svd_config if svd_config is not None else {}
        self.name_mapping = {}         # maps original name -> safe name
        self.svd_original_mapping = {} # maps safe name -> original name
        self.svd_params = nn.ModuleDict()
        self.svd_module_mapping = {}   # maps safe name -> (module, attribute_name)
        if initialize_svd:
          self._initialize_svd_parameters()

    def reinitialize_svd(self):
        """
        Reinitialize the SVD decomposition on the current (loaded) weights.
        Before reinitialization, store a copy of the original weights for each target parameter,
        then after reinitialization, check and print the reconstruction error.
        """
        # Save original weights for each parameter to be decomposed.
        self._original_weights = {}
        for orig_name in self.svd_config.keys():
            # Retrieve from the model's state_dict; ensure it is on the correct device.
            self._original_weights[orig_name] = self.state_dict()[orig_name].clone().to(device)

        # Clear previous SVD mappings.
        self.name_mapping = {}
        self.svd_original_mapping = {}
        self.svd_params = nn.ModuleDict()
        self.svd_module_mapping = {}
        # Reinitialize the SVD decomposition using the current weights.
        self._initialize_svd_parameters()

        # Now, for each decomposed parameter, compute and print the reconstruction error.
        for orig_name, safe_name in self.name_mapping.items():
            orig_weight = self._original_weights[orig_name]
            svd_dict = {
                "U_high": getattr(self, f"{safe_name}_U_high"),
                "S_high": getattr(self, f"{safe_name}_S_high"),
                "V_high": getattr(self, f"{safe_name}_V_high"),
                "U_low": self.svd_params[safe_name].U_low,
                "S_low": self.svd_params[safe_name].S_low,
                "V_low": self.svd_params[safe_name].V_low
            }
            error = check_reconstruction_error(orig_weight, svd_dict)
            print(f"Reconstruction error for {orig_name}: {error:.2e}")

    def _initialize_svd_parameters(self):
        # Iterate over all parameters
        for name, param in list(self.named_parameters()):
            if len(param.shape) == 2 and name in self.svd_config and self.svd_config[name] > 0:
                top_k = self.svd_config[name]
                print(f"[SVD Init] Decomposing {name} with top_k={top_k}")
                svd_dict = decompose_weight_matrix(param.data, top_k=top_k)
                safe_name = name.replace(".", "_")
                self.name_mapping[name] = safe_name
                self.svd_original_mapping[safe_name] = name

                # Compute the residual: the difference between the original weight and its SVD reconstruction.
                # residual = (param.data - reconstruct_weight_matrix(svd_dict)).detach()
                # Register the residual as a buffer (no gradients).
                # self.register_buffer(f"{safe_name}_residual", residual)

                # Register buffers for the high subspace
                self.register_buffer(f"{safe_name}_U_high", svd_dict["U_high"])
                self.register_buffer(f"{safe_name}_S_high", svd_dict["S_high"])
                self.register_buffer(f"{safe_name}_V_high", svd_dict["V_high"])

                # Create a module to hold the low subspace trainable parameters
                module_svd = nn.Module()
                module_svd.U_low = nn.Parameter(svd_dict["U_low"])
                module_svd.S_low = nn.Parameter(svd_dict["S_low"])
                module_svd.V_low = nn.Parameter(svd_dict["V_low"])
                module_svd.rank_high = svd_dict["rank_high"]
                module_svd.safe_name = safe_name
                self.svd_params[safe_name] = module_svd

                # Freeze the original parameter
                param.requires_grad = False

                # Pre-compute and store the module and attribute name for quick access
                mod, attr = self._get_module_by_name(name)
                if mod is not None:
                    self.svd_module_mapping[safe_name] = (mod, attr)
            # For parameters not in svd_config, leave them trainable (do nothing)

    def _reconstruct_weight(self, original_name):
        safe_name = self.name_mapping[original_name]
        U_high = getattr(self, f"{safe_name}_U_high")
        S_high = getattr(self, f"{safe_name}_S_high")
        V_high = getattr(self, f"{safe_name}_V_high")
        module_svd = self.svd_params[safe_name]
        U_low = module_svd.U_low
        S_low = module_svd.S_low
        V_low = module_svd.V_low
        svd_dict = {"U_high": U_high, "S_high": S_high, "V_high": V_high,
                    "U_low": U_low, "S_low": S_low, "V_low": V_low}
        W = reconstruct_weight_matrix(svd_dict)

        # Retrieve the residual that was stored during initialization.
        # residual = getattr(self, f"{safe_name}_residual").detach()

        # return W + residual

        return W

    def forward(self, *args, **kwargs):
        # Instead of recomputing the module mapping each time,
        # iterate over the precomputed svd_module_mapping.
        for safe_name, (module, attr) in self.svd_module_mapping.items():
            original_name = self.svd_original_mapping[safe_name]
            W = self._reconstruct_weight(original_name)
            # if attr in module._parameters:
            #     print(type(module._parameters))
            #     print(module._parameters)
            #     print(attr)
            #     module._parameters.pop(attr)
            # setattr(module, attr, W)
            # print(module._parameters)
            with torch.no_grad():
                getattr(module, attr).data.copy_(W)
        return super().forward(*args, **kwargs)

    def _get_module_by_name(self, name):
        """
        Given a full parameter name (e.g. "encoder.block.0.layer.0.SelfAttention.q.weight"),
        return (module, attribute_name) where module.attribute_name is that parameter.
        """
        parts = name.split(".")
        attr = parts[-1]
        mod = self
        for p in parts[:-1]:
            if hasattr(mod, p):
                mod = getattr(mod, p)
            elif p.isdigit():
                mod = mod[int(p)]
            else:
                return None, None
        return mod, attr

    def project_gradients(self):
        for safe_name, module_svd in self.svd_params.items():
            svd_dict = {
                "U_high": getattr(self, f"{safe_name}_U_high"),
                "S_high": getattr(self, f"{safe_name}_S_high"),
                "V_high": getattr(self, f"{safe_name}_V_high"),
                "U_low": module_svd.U_low,
                "S_low": module_svd.S_low,
                "V_low": module_svd.V_low,
            }
            project_gradient_to_orthogonal_space(svd_dict)

###################################################
# 3. Utility: Auto-generate SVD Config for Target Parameters
###################################################
# def auto_generate_target_svd_config(model):
#     """
#     Given a model, generate an SVD configuration dictionary only for parameters that contain one of the
#     following substrings:
#       - SelfAttention.q.weight
#       - SelfAttention.k.weight
#       - SelfAttention.v.weight
#       - SelfAttention.o.weight
#       - DenseReluDense.wi.weight
#       - DenseReluDense.wo.weight
#     For each such 2D parameter, set:
#          top_k = floor(min(dim0, dim1) / 2)
#     """
#     target_patterns = [
#         "SelfAttention.q.weight",
#         "SelfAttention.k.weight",
#         "SelfAttention.v.weight",
#         "SelfAttention.o.weight",
#         "DenseReluDense.wi.weight",
#         "DenseReluDense.wo.weight"
#     ]
#     config = {}
#     for name, param in model.named_parameters():
#         if any(pat in name for pat in target_patterns) and len(param.shape) == 2:
#             # rank = min(param.shape)
#             # top_k = rank // 2  # freeze top 50%
#             # if top_k > 0:
#             #     config[name] = top_k
#             effective_rank = compute_effective_rank(param.data)
#             top_k = int(np.floor(effective_rank))
#             full_rank = min(param.shape)
#             if top_k > full_rank:
#                 top_k = full_rank
#             config[name] = top_k
#     save_svd_config(config)
#     return config

def auto_generate_target_svd_config(model, tokenizer, n_samples=128, batch_size=8, num_batches=5):
    """
    For each target parameter (matching target_patterns), compute the adaptive retention ratio based on
    the importance I(W) measured using actual inputs from the AGNews test set.

    For each target parameter W (shape: (d, m), let d = min(W.shape)).
    For each such parameter:
       - Run num_batches of AGNews test data through the model with hooks to capture the input X for
         the module corresponding to W.
       - Concatenate the captured X from all batches to form a matrix X of shape (m, total_samples).
       - Compute I(W) = average cosine similarity between columns of X and Y = W @ X.
    Then normalize importance by the mean and set:
       CR(W) = 1 + (I(W)/mean(I(W)))*((d/2) - 1)
       k = round(CR(W) * d / 2)
    Clamp k between 1 and d.
    Return a dictionary mapping parameter names to top_k.
    """
    target_patterns = [
        "SelfAttention.q.weight",
        "SelfAttention.k.weight",
        "SelfAttention.v.weight",
        "SelfAttention.o.weight",
        "DenseReluDense.wi.weight",
        "DenseReluDense.wo.weight"
    ]
    # Dictionary to store importance for each target parameter.
    importance_dict = {}
    # Dictionary to store captured inputs for each target parameter.
    captured_inputs = {name: [] for name, param in model.named_parameters()
                         if any(pat in name for pat in target_patterns) and len(param.shape)==2}

    # Create hooks to capture inputs for each target module.
    hooks = {}
    def get_hook(name):
        def hook(module, input, output):
            # input[0] might have shape (batch_size, seq_length, in_features)
            X = input[0]
            # Flatten the batch and sequence dimensions into one:
            X = X.reshape(-1, X.shape[-1])  # shape: (batch_size * seq_length, in_features)
            # Transpose so that columns represent individual samples:
            captured_inputs[name].append(X.transpose(0, 1).detach())
        return hook

    # For each target parameter, register a hook on its parent module.
    for name, param in model.named_parameters():
        if any(pat in name for pat in target_patterns) and len(param.shape)==2:
            mod, attr = model._get_module_by_name(name)
            if mod is not None:
                hooks[name] = mod.register_forward_hook(get_hook(name))

    # Now run a few batches of AGNews test data.
    from datasets import load_dataset
    agnews = load_dataset("ag_news", split="test")
    # Use a simple dataset: each sample is a string "classify: <text>".
    inputs = ["classify: " + sample["text"] for sample in agnews.select(range(n_samples))]
    encodings = tokenizer(inputs, padding=True, truncation=True, max_length=512, return_tensors="pt")
    # Wrap the BatchEncoding in a custom Dataset
    class BatchEncodingDataset(Dataset):
        def __init__(self, encodings):
            self.encodings = encodings
        def __len__(self):
            return self.encodings["input_ids"].shape[0]
        def __getitem__(self, idx):
            return {key: val[idx] for key, val in self.encodings.items()}

    agnews_dataset = BatchEncodingDataset(encodings)
    agnews_loader = DataLoader(agnews_dataset, batch_size=batch_size)
    # agnews_loader = DataLoader(encodings, batch_size=batch_size)

    model = model.to(device)
    model.eval()
    batches = 0
    with torch.no_grad():
        for batch in agnews_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # _ = model(**batch)  # forward pass to trigger hooks
            # batches += 1
            # if batches >= num_batches:
            #     break

            batch_size = batch["input_ids"].shape[0]
            # Create a dummy decoder input using the model's decoder_start_token_id.
            # T5 usually uses 0 or the value from config.decoder_start_token_id.
            dummy_decoder_input_ids = torch.full(
                (batch_size, 1),
                model.config.decoder_start_token_id,
                device=device,
                dtype=batch["input_ids"].dtype
            )
            # Forward pass with both encoder and decoder inputs.
            _ = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                decoder_input_ids=dummy_decoder_input_ids
            )

    # Remove hooks.
    for h in hooks.values():
        h.remove()

    # Now compute importance for each target parameter.
    for name in captured_inputs.keys():
        # Concatenate captured inputs along last dimension.
        X = torch.cat(captured_inputs[name], dim=1).to(device)  # shape: (in_features, total_samples)
        W = model.state_dict()[name].to(device)
        Y = torch.mm(W, X)

        # Determine m = min(W.shape) and slice both X and Y to the first m rows.
        m = min(W.shape)
        X_mod = X[:m, :]
        Y_mod = Y[:m, :]

        X_norm = X_mod / (torch.norm(X_mod, dim=0, keepdim=True) + 1e-8)
        Y_norm = Y_mod / (torch.norm(Y_mod, dim=0, keepdim=True) + 1e-8)
        cosine_sim = torch.sum(X_norm * Y_norm, dim=0)
        I_W = torch.mean(cosine_sim).item()
        importance_dict[name] = I_W

    mean_importance = np.mean(list(importance_dict.values()))
    config = {}
    for name, param in model.named_parameters():
        if name in importance_dict:
            d = min(param.shape)
            I_W = importance_dict[name]
            I_n = I_W / (mean_importance + 1e-8)
            mrr = d / 2.0 # 1.0
            trr = d # d / 2.0
            CR = mrr + I_n * (trr - mrr)
            # As explained: full params of W is 2*d^2 (for square W) and retained params is 2*d*k,
            # so we set k/d = CR  => k = CR * d.
            # k = int(round(CR * d))
            k = int(round(CR))
            k = max(1, min(k, d))
            config[name] = k
    save_svd_config(config)
    return config

###################################################
# 4. Amazon Reviews Dataset (Using amazon_polarity)
###################################################
class AmazonReviewsDataset(Dataset):
    """
    Wraps the amazon_polarity dataset. Assumes each sample has a "content" field
    and a "label" field. We convert each example into a text-to-text format.
    """
    def __init__(self, hf_dataset, split, tokenizer, label_mapping):
        # self.dataset = hf_dataset[split]
        self.dataset = hf_dataset[split].shuffle(seed=42).select(range(3600))  # Select only 3.6k samples
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        text = sample["content"]
        label = sample["label"]
        input_text = "classify as positive or negative: " + text
        target_text = self.label_mapping[label]
        return input_text, target_text

def collate_fn_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer(list(targets), padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings

###################################################
# 5. Training and Saving the SVD Model on Amazon Reviews
###################################################
def train_svd_model():
    # Load Amazon polarity dataset
    hf_dataset = load_dataset("amazon_polarity")
    # Define label mapping (assuming 0: negative, 1: positive)
    label_mapping = {0: "negative", 1: "positive"}

    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False  # disable cache for training

    # Load a base T5 model to auto-generate the target SVD config.
    # base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    base_model = T5WithSVD(config, svd_config={}, initialize_svd=False)
    base_model.load_state_dict(torch.load('t5_finetuned_agnews.pt', map_location=device), strict=False)
    base_model = base_model.to(device)
    # target_svd_config = auto_generate_target_svd_config(base_model)
    target_svd_config = auto_generate_target_svd_config(base_model, T5Tokenizer.from_pretrained("t5-small"))
    print("Auto-generated target SVD config:")
    for k, v in target_svd_config.items():
        print(f"  {k}: freeze top {v} singular vectors")

    # Initialize our custom SVD model with target_svd_config.
    model = T5WithSVD(config, svd_config=target_svd_config)
    # Load pretrained weights into our SVD model.
    model.load_state_dict(torch.load('t5_finetuned_agnews.pt', map_location=device), strict=False)
    model.reinitialize_svd()
    model = model.to(device)

    # # Load the original AGNews state dictionary
    # orig_state = torch.load('t5_finetuned_agnews.pt', map_location=device)

    # # For each parameter in the original state, compare with the corresponding effective parameter in model.
    # for name, orig_param in orig_state.items():
    #     # If this parameter was decomposed (present in our svd_config), then use our reconstruction function.
    #     if name in model.svd_config:
    #         # Compute effective weight from SVD (including residual)
    #         effective_weight = model._reconstruct_weight(name)
    #         # Compute relative error
    #         error = torch.norm(orig_param.to(device) - effective_weight) / torch.norm(orig_param.to(device))
    #         print(f"{name} (decomposed): relative error = {error.item():.2e}")
    #     else:
    #         # Otherwise, compare directly.
    #         try:
    #             model_param = model.state_dict()[name]
    #             error = torch.norm(orig_param.to(device) - model_param) / torch.norm(orig_param.to(device))
    #             print(f"{name} (not decomposed): relative error = {error.item():.2e}")
    #         except KeyError:
    #             print(f"{name} is not present in the current model state_dict.")

    # torch.save(model.state_dict(), "t5_svd_amazon.pt")

    # Create datasets and dataloaders
    train_dataset = AmazonReviewsDataset(hf_dataset, "train", tokenizer, label_mapping)
    test_dataset = AmazonReviewsDataset(hf_dataset, "test", tokenizer, label_mapping)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=lambda batch: collate_fn_fn(batch, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn_fn(batch, tokenizer))

    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    num_epochs = 3  # adjust as needed

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            for key, val in batch.items():
                batch[key] = val.to(device)
            outputs = model(**batch, use_cache=False)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            model.project_gradients()  # ensure gradients remain in correct subspace
            optimizer.step()

            total_loss += loss.item()
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    # Save the fine-tuned model (with SVD modifications)
    torch.save(model.state_dict(), "t5_svd_amazon.pt")
    print("Model saved as 't5_svd_amazon.pt'")
    return model, tokenizer, test_loader

###################################################
# 6. Inference
###################################################
def inference_svd_model():
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False
    # Re-generate the same target SVD configuration
    base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    target_svd_config = auto_generate_target_svd_config(base_model)
    model = T5WithSVD(config, svd_config=target_svd_config)
    model.load_state_dict(torch.load("t5_svd_amazon.pt"), strict=False)
    model = model.to(device)
    model.eval()

    # Try a generation example – here we provide a sample review.
    input_text = "classify: This product exceeded my expectations and works perfectly!"
    input_enc = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = model.generate(**input_enc, max_length=16)
    print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

###################################################
# 6. Evaluation on Test Set
###################################################
def evaluate_model(model, tokenizer, test_loader):
    model.eval()
    total, correct = 0, 0
    for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
        # Move batch tensors to device
        for key, val in batch.items():
            batch[key] = val.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
        # Decode predictions and targets
        predictions = [tokenizer.decode(g, skip_special_tokens=True).strip().lower()
                       for g in generated_ids]
        targets = [tokenizer.decode(label, skip_special_tokens=True).strip().lower()
                   for label in batch["labels"]]
        for pred, target in zip(predictions, targets):
            total += 1
            if pred == target:
                correct += 1
    accuracy = correct / total if total > 0 else 0
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

###################################################
# 7. Main
###################################################
if __name__ == "__main__":
    model, tokenizer, test_loader = train_svd_model()
    evaluate_model(model, tokenizer, test_loader)
    # inference_svd_model()

  base_model.load_state_dict(torch.load('t5_finetuned_agnews.pt', map_location=device), strict=False)


Auto-generated target SVD config:
  encoder.block.0.layer.0.SelfAttention.q.weight: freeze top 248 singular vectors
  encoder.block.0.layer.0.SelfAttention.k.weight: freeze top 1 singular vectors
  encoder.block.0.layer.0.SelfAttention.v.weight: freeze top 1 singular vectors
  encoder.block.0.layer.0.SelfAttention.o.weight: freeze top 1 singular vectors
  encoder.block.0.layer.1.DenseReluDense.wi.weight: freeze top 340 singular vectors
  encoder.block.0.layer.1.DenseReluDense.wo.weight: freeze top 1 singular vectors
  encoder.block.1.layer.0.SelfAttention.q.weight: freeze top 512 singular vectors
  encoder.block.1.layer.0.SelfAttention.k.weight: freeze top 512 singular vectors
  encoder.block.1.layer.0.SelfAttention.v.weight: freeze top 453 singular vectors
  encoder.block.1.layer.0.SelfAttention.o.weight: freeze top 1 singular vectors
  encoder.block.1.layer.1.DenseReluDense.wi.weight: freeze top 294 singular vectors
  encoder.block.1.layer.1.DenseReluDense.wo.weight: freeze top 512 s

  model.load_state_dict(torch.load('t5_finetuned_agnews.pt', map_location=device), strict=False)


[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.q.weight with top_k=248
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.k.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.v.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.o.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wi.weight with top_k=340
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wo.weight with top_k=1
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.q.weight with top_k=512
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.k.weight with top_k=512
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.v.weight with top_k=453
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.o.weight with top_k=1
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDense.wi.weight with top_k=294
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDense.wo.weig

Epoch 1/3: 100%|██████████| 450/450 [04:48<00:00,  1.56batch/s, eta=0.64s, loss=0.2169]


Epoch 1/3 - Average Loss: 1.4814


Epoch 2/3: 100%|██████████| 450/450 [04:49<00:00,  1.55batch/s, eta=0.64s, loss=0.2639]


Epoch 2/3 - Average Loss: 0.2341


Epoch 3/3: 100%|██████████| 450/450 [04:49<00:00,  1.55batch/s, eta=0.64s, loss=0.0687]


Epoch 3/3 - Average Loss: 0.1825
Model saved as 't5_svd_amazon.pt'


Evaluating: 100%|██████████| 450/450 [00:34<00:00, 12.86batch/s]

Test Accuracy: 88.94%





In [2]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###################################################
# 1. Define a PyTorch Dataset for Amazon Reviews
###################################################
class AmazonReviewsDataset(Dataset):
    """
    A PyTorch dataset wrapper for the amazon_polarity dataset.
    Each sample is converted to a text-to-text format.
    """
    def __init__(self, hf_dataset, split, tokenizer, label_mapping):
        """
        hf_dataset: the Hugging Face dataset loaded via load_dataset("amazon_polarity")
        split: "train" or "test"
        tokenizer: a T5Tokenizer instance
        label_mapping: a dict mapping integer labels to string labels, e.g. {0:"negative", 1:"positive"}
        """
        self.dataset = hf_dataset[split].shuffle(seed=42).select(range(3600))
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # The amazon_polarity dataset uses "content" for the review text and "label" for polarity.
        text = sample["content"]
        label = sample["label"]
        input_text = "classify as positive or negative: " + text
        target_text = self.label_mapping[label]
        return input_text, target_text

###################################################
# 2. Collate Function
###################################################
def collate_fn_amazon(batch, tokenizer, max_source_length=512, max_target_length=16):
    """
    Tokenize a batch of input and target texts.
    Returns a dictionary with input_ids, attention_mask, and labels.
    """
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer(list(targets), padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings

###################################################
# 3. Training and Evaluation Functions for Amazon Reviews
###################################################
def train_finetune_amazon():
    # Load the amazon_polarity dataset from Hugging Face
    hf_dataset = load_dataset("amazon_polarity")

    # Define the label mapping (0: negative, 1: positive)
    label_mapping = {0: "negative", 1: "positive"}

    # Load pretrained T5 tokenizer and model (T5-small)
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    state_dict = torch.load("t5_finetuned_agnews.pt", map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device)

    # Create PyTorch datasets for train and test splits
    train_dataset = AmazonReviewsDataset(hf_dataset, "train", tokenizer, label_mapping)
    test_dataset = AmazonReviewsDataset(hf_dataset, "test", tokenizer, label_mapping)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=lambda batch: collate_fn_amazon(batch, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn_amazon(batch, tokenizer))

    # Prepare optimizer (full fine-tuning; all parameters will be updated)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    num_epochs = 3  # Adjust the number of epochs as needed

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            # Move batch tensors to device
            for key, val in batch.items():
                batch[key] = val.to(device)
            outputs = model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    # Save the fine-tuned model on Amazon Reviews
    torch.save(model.state_dict(), "t5_finetuned_amazon.pt")
    print("Model saved as 't5_finetuned_amazon.pt'.")
    return model, tokenizer, test_loader


def evaluate(model, tokenizer, test_loader):
    """
    Evaluate the fine-tuned model on the test set.
    The model generates predictions which are compared (after decoding) to the expected labels.
    """
    model.eval()
    total, correct = 0, 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
            for key, val in batch.items():
                batch[key] = val.to(device)
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
            # Decode predictions and targets
            predictions = [tokenizer.decode(g, skip_special_tokens=True).strip().lower() for g in generated_ids]
            targets = [tokenizer.decode(t, skip_special_tokens=True).strip().lower() for t in batch["labels"]]
            for pred, target in zip(predictions, targets):
                total += 1
                if pred == target:
                    correct += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Test Accuracy on Amazon Reviews: {accuracy*100:.2f}%")
    return accuracy


###################################################
# 4. Main: Train and Evaluate on Amazon Reviews
###################################################
if __name__ == "__main__":
    model, tokenizer, test_loader = train_finetune_amazon()
    evaluate(model, tokenizer, test_loader)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.81k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/260M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/258M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/255M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/254M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/117M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3600000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/400000 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  state_dict = torch.load("t5_finetuned_agnews.pt", map_location=device)
Epoch 1/3:   0%|          | 0/450 [00:00<?, ?batch/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Epoch 1/3: 100%|██████████| 450/450 [00:53<00:00,  8.37batch/s, eta=0.12s, loss=0.1122]


Epoch 1/3 - Average Loss: 0.6280


Epoch 2/3: 100%|██████████| 450/450 [00:52<00:00,  8.52batch/s, eta=0.12s, loss=0.0347]


Epoch 2/3 - Average Loss: 0.1500


Epoch 3/3: 100%|██████████| 450/450 [00:53<00:00,  8.45batch/s, eta=0.12s, loss=0.0778]


Epoch 3/3 - Average Loss: 0.1150
Model saved as 't5_finetuned_amazon.pt'.


Evaluating: 100%|██████████| 450/450 [00:24<00:00, 18.20batch/s]

Test Accuracy on Amazon Reviews: 91.00%





In [3]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###################################################
# 1. AG News Dataset Class
###################################################
class AGNewsDataset(Dataset):
    """
    A PyTorch dataset wrapper for the AG News dataset.
    Each example is converted to a text-to-text format.
    """
    def __init__(self, hf_dataset, split, tokenizer, label_mapping):
        """
        hf_dataset: the Hugging Face dataset loaded via load_dataset("ag_news")
        split: "train" or "test"
        tokenizer: a T5Tokenizer instance
        label_mapping: a dict mapping integer labels to string labels, e.g. {0: "World", ...}
        """
        self.dataset = hf_dataset[split]
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        text = sample["text"]
        label = sample["label"]
        # Use the same prompt as used during training
        input_text = "classify: " + text
        target_text = self.label_mapping[label]
        return input_text, target_text

###################################################
# 2. Collate Function
###################################################
def collate_fn_agnews(batch, tokenizer, max_source_length=512, max_target_length=16):
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer(list(targets), padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings

###################################################
# 3. Evaluation Function
###################################################
def evaluate_agnews(model, tokenizer, test_loader, num_examples=5):
    model.eval()
    total, correct = 0, 0
    examples_printed = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating AG News", unit="batch"):
            # Move the batch to the device
            for key, val in batch.items():
                batch[key] = val.to(device)
            # Generate outputs using a max_length suitable for short target texts
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
            # Decode the generated ids and ground-truth labels
            predictions = [tokenizer.decode(g, skip_special_tokens=True).strip().lower()
                           for g in generated_ids]
            targets = [tokenizer.decode(t, skip_special_tokens=True).strip().lower()
                       for t in batch["labels"]]
            for pred, target in zip(predictions, targets):
                total += 1
                if pred == target:
                    correct += 1
                if examples_printed < num_examples:
                  print(f"Example {examples_printed+1}:")
                  print(f"  Prediction: {pred}")
                  print(f"  Target:     {target}")
                  print("-" * 40)
                  examples_printed += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"AG News Test Accuracy: {accuracy*100:.2f}%")
    return accuracy

###################################################
# 4. Main Evaluation Script
###################################################
def main():
    model_name = "t5-small"
    # Load the AG News dataset from Hugging Face
    hf_dataset = load_dataset("ag_news")
    # Define label mapping (as used during training)
    label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    # Initialize tokenizer
    tokenizer = T5Tokenizer.from_pretrained(model_name)

    # Instantiate a T5 model and load the fine-tuned AG News state dict
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    state_dict = torch.load("t5_finetuned_amazon.pt", map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device)

    # Create the AG News test dataset and loader
    test_dataset = AGNewsDataset(hf_dataset, "test", tokenizer, label_mapping)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn_agnews(batch, tokenizer))

    # Evaluate the model
    evaluate_agnews(model, tokenizer, test_loader)

if __name__ == "__main__":
    main()

README.md:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

  state_dict = torch.load("t5_finetuned_amazon.pt", map_location=device)
Evaluating AG News:   0%|          | 3/950 [00:00<01:15, 12.48batch/s]

Example 1:
  Prediction: world
  Target:     business
----------------------------------------
Example 2:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------
Example 3:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------
Example 4:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------
Example 5:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------


Evaluating AG News: 100%|██████████| 950/950 [01:01<00:00, 15.51batch/s]

AG News Test Accuracy: 90.13%





In [7]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###############################
# Evaluation for Normal T5 Model
# (t5_finetuned_agnews.pt on Amazon Reviews Test)
###############################

def evaluate_normal_t5_on_amazonreviews(model_path, num_examples=5):
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    # Load the fully finetuned (normal) T5 model
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Load Amazon Reviews test set
    hf_dataset = load_dataset("amazon_polarity", split="test")
    # Define label mapping (0: negative, 1: positive)
    label_mapping = {0: "negative", 1: "positive"}

    class AmazonReviewsDataset(Dataset):
        def __init__(self, dataset, tokenizer, label_mapping):
            self.dataset = dataset.shuffle(seed=42).select(range(3600))
            self.tokenizer = tokenizer
            self.label_mapping = label_mapping
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            sample = self.dataset[idx]
            text = sample["content"]
            label = sample["label"]
            input_text = "classify as positive or negative: " + text
            target_text = self.label_mapping[label]
            return input_text, target_text

    def collate_fn_amazon(batch):
        inputs, targets = zip(*batch)
        input_enc = tokenizer(list(inputs), padding=True, truncation=True, max_length=512, return_tensors="pt")
        target_enc = tokenizer(list(targets), padding=True, truncation=True, max_length=16, return_tensors="pt")
        input_enc["labels"] = target_enc["input_ids"]
        return input_enc

    test_dataset = AmazonReviewsDataset(hf_dataset, tokenizer, label_mapping)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_amazon)

    total, correct = 0, 0
    examples_printed = 0
    for batch in tqdm(test_loader, desc="Evaluating Normal T5 on Amazon Reviews"):
        for k, v in batch.items():
            batch[k] = v.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
        # Decode predictions and targets (strip and lower for uniformity)
        preds = [tokenizer.decode(g, skip_special_tokens=True).strip().lower() for g in generated_ids]
        targets = [tokenizer.decode(t, skip_special_tokens=True).strip().lower() for t in batch["labels"]]
        for pred, target in zip(preds, targets):
            total += 1
            if pred == target:
                correct += 1
            if examples_printed < num_examples:
                print(f"Example {examples_printed+1}:")
                print(f"  Prediction: {pred}")
                print(f"  Target:     {target}")
                print("-" * 40)
                examples_printed += 1
    accuracy = correct / total if total > 0 else 0.0
    print(f"Normal T5 on Amazon Reviews Test Accuracy: {accuracy*100:.2f}%")


###############################
# Evaluation for T5-SVD Model
# (t5_svd_amazon.pt on AG News Test)
###############################

# (Assuming the T5WithSVD class and auto_generate_target_svd_config function are already defined above)
def evaluate_t5svd_on_agnews(model_path, num_examples=5):
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False
    # Re-generate the same target SVD configuration used during training
    # base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    # target_svd_config = auto_generate_target_svd_config(base_model)
    target_svd_config = load_svd_config("svd_config.json")
    # Instantiate the modified model
    model = T5WithSVD(config, svd_config=target_svd_config)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.reinitialize_svd()
    model.to(device)
    model.eval()

    # Load AG News test set
    ag_news = load_dataset("ag_news", split="test")
    # Define AG News label mapping
    label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    class AGNewsDataset(Dataset):
        def __init__(self, dataset, tokenizer, label_mapping):
            self.dataset = dataset
            self.tokenizer = tokenizer
            self.label_mapping = label_mapping
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            sample = self.dataset[idx]
            # For classification, we use a "classify:" prompt
            input_text = "classify: " + sample["text"]
            target_text = self.label_mapping[sample["label"]]
            return input_text, target_text

    def collate_fn_agnews(batch):
        inputs, targets = zip(*batch)
        input_enc = tokenizer(list(inputs), padding=True, truncation=True, max_length=512, return_tensors="pt")
        target_enc = tokenizer(list(targets), padding=True, truncation=True, max_length=16, return_tensors="pt")
        input_enc["labels"] = target_enc["input_ids"]
        return input_enc

    test_dataset = AGNewsDataset(ag_news, tokenizer, label_mapping)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_agnews)

    total, correct = 0, 0
    examples_printed = 0
    for batch in tqdm(test_loader, desc="Evaluating T5-SVD on AG News"):
        for k, v in batch.items():
            batch[k] = v.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16,
                # max_new_tokens=1
            )
        preds = [tokenizer.decode(g, skip_special_tokens=True).strip().lower() for g in generated_ids]
        targets = [tokenizer.decode(t, skip_special_tokens=True).strip().lower() for t in batch["labels"]]
        for pred, target in zip(preds, targets):
            total += 1
            if pred == target:
                correct += 1
            if examples_printed < num_examples:
                print(f"Example {examples_printed+1}:")
                print(f"  Prediction: {pred}")
                print(f"  Target:     {target}")
                print("-" * 40)
                examples_printed += 1
    accuracy = correct / total if total > 0 else 0.0
    print(f"T5-SVD on AG News Test Accuracy: {accuracy*100:.2f}%")


###############################
# Main – Run Both Evaluations
###############################
if __name__ == "__main__":
    # Evaluate the normal model (trained on AGNews) on the Amazon Reviews test set.
    print("Evaluating normal T5 model (t5_finetuned_agnews.pt) on Amazon Reviews test set:")
    evaluate_normal_t5_on_amazonreviews("t5_finetuned_agnews.pt")

    # Evaluate the modified T5-SVD model (trained on Amazon Reviews) on the AG News test set.
    print("\nEvaluating T5-SVD model (t5_svd_amazon.pt) on AG News test set:")
    evaluate_t5svd_on_agnews("t5_svd_amazon.pt")

Evaluating normal T5 model (t5_finetuned_agnews.pt) on Amazon Reviews test set:


  state_dict = torch.load(model_path, map_location=device)
Evaluating Normal T5 on Amazon Reviews:   0%|          | 2/450 [00:00<00:46,  9.71it/s]

Example 1:
  Prediction: sci/tech
  Target:     positive
----------------------------------------
Example 2:
  Prediction: sci/tech
  Target:     negative
----------------------------------------
Example 3:
  Prediction: sci/tech
  Target:     negative
----------------------------------------
Example 4:
  Prediction: sci/tech
  Target:     negative
----------------------------------------
Example 5:
  Prediction: sports
  Target:     negative
----------------------------------------


Evaluating Normal T5 on Amazon Reviews: 100%|██████████| 450/450 [00:33<00:00, 13.26it/s]


Normal T5 on Amazon Reviews Test Accuracy: 0.00%

Evaluating T5-SVD model (t5_svd_amazon.pt) on AG News test set:
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.q.weight with top_k=248
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.k.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.v.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.o.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wi.weight with top_k=340
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wo.weight with top_k=1
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.q.weight with top_k=512
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.k.weight with top_k=512
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.v.weight with top_k=453
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.o.weight with top_k=1
[SVD Init] Decomposing encoder.block.1.lay

  state_dict = torch.load(model_path, map_location=device)


[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.q.weight with top_k=248
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.k.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.v.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.o.weight with top_k=1
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wi.weight with top_k=340
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wo.weight with top_k=1
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.q.weight with top_k=512
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.k.weight with top_k=512
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.v.weight with top_k=453
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.o.weight with top_k=1
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDense.wi.weight with top_k=294
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDense.wo.weig

Evaluating T5-SVD on AG News:   0%|          | 2/950 [00:00<02:15,  7.01it/s]

Example 1:
  Prediction: business
  Target:     business
----------------------------------------
Example 2:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------
Example 3:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------
Example 4:
  Prediction: sports
  Target:     sci/tech
----------------------------------------
Example 5:
  Prediction: sci/tech
  Target:     sci/tech
----------------------------------------


Evaluating T5-SVD on AG News: 100%|██████████| 950/950 [01:34<00:00, 10.08it/s]

T5-SVD on AG News Test Accuracy: 92.39%





In [None]:
###############################
# Evaluation for T5-SVD Model
# (t5_svd_amazon.pt on AG News Test)
###############################

# (Assuming the T5WithSVD class and auto_generate_target_svd_config function are already defined above)
def evaluate_t5svd_on_agnews(model_path, num_examples=5):
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False
    # Re-generate the same target SVD configuration used during training
    base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    target_svd_config = auto_generate_target_svd_config(base_model)
    # Instantiate the modified model
    model = T5WithSVD(config, svd_config=target_svd_config)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()

    # Load Amazon Reviews test set
    hf_dataset = load_dataset("amazon_polarity", split="test")
    # Define label mapping (0: negative, 1: positive)
    label_mapping = {0: "negative", 1: "positive"}

    class AmazonReviewsDataset(Dataset):
        def __init__(self, dataset, tokenizer, label_mapping):
            self.dataset = dataset.shuffle(seed=42).select(range(3600))
            self.tokenizer = tokenizer
            self.label_mapping = label_mapping
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            sample = self.dataset[idx]
            text = sample["content"]
            label = sample["label"]
            input_text = "classify as positive or negative: " + text
            target_text = self.label_mapping[label]
            return input_text, target_text

    def collate_fn_amazon(batch):
        inputs, targets = zip(*batch)
        input_enc = tokenizer(list(inputs), padding=True, truncation=True, max_length=512, return_tensors="pt")
        target_enc = tokenizer(list(targets), padding=True, truncation=True, max_length=16, return_tensors="pt")
        input_enc["labels"] = target_enc["input_ids"]
        return input_enc

    test_dataset = AmazonReviewsDataset(hf_dataset, tokenizer, label_mapping)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_amazon)

    total, correct = 0, 0
    examples_printed = 0
    for batch in tqdm(test_loader, desc="Evaluating T5-SVD on AG News"):
        for k, v in batch.items():
            batch[k] = v.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
        preds = [tokenizer.decode(g, skip_special_tokens=True).strip().lower() for g in generated_ids]
        targets = [tokenizer.decode(t, skip_special_tokens=True).strip().lower() for t in batch["labels"]]
        for pred, target in zip(preds, targets):
            total += 1
            if pred == target:
                correct += 1
            if examples_printed < num_examples:
                print(f"Example {examples_printed+1}:")
                print(f"  Prediction: {pred}")
                print(f"  Target:     {target}")
                print("-" * 40)
                examples_printed += 1
    accuracy = correct / total if total > 0 else 0.0
    print(f"T5-SVD on AG News Test Accuracy: {accuracy*100:.2f}%")


###############################
# Main – Run Both Evaluations
###############################
if __name__ == "__main__":
    # Evaluate the modified T5-SVD model (trained on Amazon Reviews) on the AG News test set.
    print("\nEvaluating T5-SVD model (t5_svd_amazon.pt) on AG News test set:")
    evaluate_t5svd_on_agnews("t5_svd_amazon.pt")


Evaluating T5-SVD model (t5_svd_amazon.pt) on AG News test set:


model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.q.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.k.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.v.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.o.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wi.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wo.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.q.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.k.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.v.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.o.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDense.wi.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDen

  state_dict = torch.load(model_path, map_location=device)


README.md:   0%|          | 0.00/6.81k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/260M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/258M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/255M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/254M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/117M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3600000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/400000 [00:00<?, ? examples/s]

Evaluating T5-SVD on AG News:   0%|          | 2/450 [00:01<04:17,  1.74it/s]

Example 1:
  Prediction: negative
  Target:     positive
----------------------------------------
Example 2:
  Prediction: negative
  Target:     negative
----------------------------------------
Example 3:
  Prediction: negative
  Target:     negative
----------------------------------------
Example 4:
  Prediction: negative
  Target:     negative
----------------------------------------
Example 5:
  Prediction: negative
  Target:     negative
----------------------------------------


Evaluating T5-SVD on AG News: 100%|██████████| 450/450 [00:34<00:00, 13.07it/s]

T5-SVD on AG News Test Accuracy: 78.86%





In [None]:
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
config = T5Config.from_pretrained(model_name)
config.use_cache = False
# Re-generate the same target SVD configuration used during training
base_model = T5ForConditionalGeneration.from_pretrained(model_name)
target_svd_config = auto_generate_target_svd_config(base_model)
# Instantiate the modified model
model = T5WithSVD(config, svd_config=target_svd_config)
state_dict = torch.load('t5_svd_amazon.pt', map_location=device)
model.load_state_dict(state_dict, strict=False)

[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.q.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.k.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.v.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.0.SelfAttention.o.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wi.weight with top_k=256
[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wo.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.q.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.k.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.v.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.0.SelfAttention.o.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDense.wi.weight with top_k=256
[SVD Init] Decomposing encoder.block.1.layer.1.DenseReluDen

  state_dict = torch.load('t5_svd_amazon.pt', map_location=device)


_IncompatibleKeys(missing_keys=['encoder.block.0.layer.0.SelfAttention.q.weight', 'encoder.block.0.layer.0.SelfAttention.k.weight', 'encoder.block.0.layer.0.SelfAttention.v.weight', 'encoder.block.0.layer.0.SelfAttention.o.weight', 'encoder.block.0.layer.1.DenseReluDense.wi.weight', 'encoder.block.0.layer.1.DenseReluDense.wo.weight', 'encoder.block.1.layer.0.SelfAttention.q.weight', 'encoder.block.1.layer.0.SelfAttention.k.weight', 'encoder.block.1.layer.0.SelfAttention.v.weight', 'encoder.block.1.layer.0.SelfAttention.o.weight', 'encoder.block.1.layer.1.DenseReluDense.wi.weight', 'encoder.block.1.layer.1.DenseReluDense.wo.weight', 'encoder.block.2.layer.0.SelfAttention.q.weight', 'encoder.block.2.layer.0.SelfAttention.k.weight', 'encoder.block.2.layer.0.SelfAttention.v.weight', 'encoder.block.2.layer.0.SelfAttention.o.weight', 'encoder.block.2.layer.1.DenseReluDense.wi.weight', 'encoder.block.2.layer.1.DenseReluDense.wo.weight', 'encoder.block.3.layer.0.SelfAttention.q.weight', 'encod

In [None]:
import os
import json
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

torch.autograd.set_detect_anomaly(True)

###################################################
# 1. Define a dataset (dummy example)
###################################################
class DummySeq2SeqDataset(Dataset):
    """
    A trivial dataset that returns (input_text, target_text).
    Replace with your real dataset.
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]["input"], self.data[idx]["target"]

def collate_fn(batch, tokenizer, max_length=128):
    """
    Tokenize and prepare the batch for T5.
    """
    inputs, targets = zip(*batch)
    input_enc = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    target_enc = tokenizer(list(targets), padding=True, truncation=True, max_length=max_length, return_tensors="pt")

    # T5 uses 'labels' for the decoder
    return {
        "input_ids": input_enc["input_ids"],
        "attention_mask": input_enc["attention_mask"],
        "labels": target_enc["input_ids"],
    }

###################################################
# 2. Helper function for SVD and param management
###################################################

def decompose_weight_matrix(weight: torch.Tensor, top_k: int):
    """
    Perform SVD on a 2D weight matrix and split into:
      - top_k singular vectors (treated as frozen/buffers)
      - the rest (treated as trainable)
    Returns a dictionary containing:
      {
        "U_high": ...  # buffer
        "S_high": ...  # buffer
        "V_high": ...  # buffer
        "U_low": ...   # parameter
        "S_low": ...   # parameter
        "V_low": ...   # parameter
        "rank_high": top_k
      }
    """
    # SVD
    # shape of W = (out_features, in_features)
    device = weight.device
    W = weight.to(torch.float32)  # ensure float32 for SVD
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)  # Vt has shape (in_features, in_features)

    # If top_k >= min(U.shape[1], Vt.shape[0]), clamp it
    k = min(top_k, S.shape[0])

    # High subspace (frozen)
    U_high = U[:, :k].detach()
    S_high = S[:k].detach()
    V_high = Vt[:k, :].detach()

    # Low subspace (trainable)
    U_low = U[:, k:].detach()
    S_low = S[k:].detach()
    V_low = Vt[k:, :].detach()

    # Move them to correct device
    U_high = U_high.to(device)
    S_high = S_high.to(device)
    V_high = V_high.to(device)
    U_low = U_low.to(device)
    S_low = S_low.to(device)
    V_low = V_low.to(device)

    # Wrap the "low" parts as parameters; "high" parts as buffers
    return {
        "U_high": U_high,  # no gradient
        "S_high": S_high,  # no gradient
        "V_high": V_high,  # no gradient
        "U_low": nn.Parameter(U_low),  # trainable
        "S_low": nn.Parameter(S_low),  # trainable
        "V_low": nn.Parameter(V_low),  # trainable
        "rank_high": k
    }


def reconstruct_weight_matrix(svd_dict):
    """
    Reconstruct the weight matrix from both high and low subspaces:
        W = U_high * diag(S_high) * V_high^T + U_low * diag(S_low) * V_low^T
    """
    U_high = svd_dict["U_high"]
    S_high = svd_dict["S_high"]
    V_high = svd_dict["V_high"]
    U_low = svd_dict["U_low"]
    S_low = svd_dict["S_low"]
    V_low = svd_dict["V_low"]

    if U_high.shape[1] > 0 and S_high.shape[0] > 0:
        high_part = torch.mm(U_high * S_high.unsqueeze(0), V_high)
    else:
        high_part = torch.zeros(U_low.size(0), V_low.size(1), device=U_high.device)

    # Reconstruct low part (with gradient)
    if U_low.shape[1] > 0 and S_low.shape[0] > 0:
        # Use explicit matrix multiplication for better gradient flow
        US_low = U_low * S_low.unsqueeze(0)
        low_part = torch.mm(US_low, V_low)
    else:
        low_part = torch.zeros(U_high.size(0), V_high.size(1), device=U_low.device)

    return high_part + low_part


def project_gradient_to_orthogonal_space(svd_dict):
    """
    Zero out gradients that lie in the direction of the high subspace for each param in the low subspace.
    In other words, ensure that d(U_low), d(S_low), d(V_low) are orthogonal to the subspace spanned by U_high, V_high.

    For example, we can do something like:
       dU_low = dU_low - (U_high * (U_high^T @ dU_low))
    to remove any components in the column space of U_high. Similarly for V_low.
    This is a simplistic approach.
    """
    # If there's no gradient, return
    if svd_dict["U_low"].grad is None and svd_dict["S_low"].grad is None and svd_dict["V_low"].grad is None:
        return

    U_high = svd_dict["U_high"]
    V_high = svd_dict["V_high"]

    # Project out from U_low.grad
    if svd_dict["U_low"].grad is not None:
        dU = svd_dict["U_low"].grad
        # Remove the component that lies in col-space of U_high
        # col-space of U_high is spanned by columns of U_high
        # We'll do: dU <- dU - U_high (U_high^T dU)
        proj = U_high @ (U_high.transpose(0,1) @ dU)
        dU.sub_(proj)  # in-place

    # Project out from V_low.grad
    if svd_dict["V_low"].grad is not None:
        dV = svd_dict["V_low"].grad
        # V_high has shape (k, in_features). The row-space is spanned by rows of V_high
        # We want to remove any component of dV that is in row-space of V_high.
        # row-space of V_high is spanned by each row vector -> equivalently col-space of V_high^T
        # So: dV <- dV - ( (dV V_high^T) V_high )
        # But we have to do that carefully with shapes.
        # One simpler approach is: for each row i of dV, project out from row i of V_high.
        # We'll do a matrix approach with unsqueeze expansions. This can get tricky.

        # Let's do it in a more direct manner:
        # We can think of dV as (r_low, c) shaped. row-space is dimension r_low, col c
        # row-space of V_high is dimension k. We want to remove the projection onto each row of V_high.
        # A quick hack: project columns in col-space of V_high^T. We'll treat each row as a vector,
        # so we do: dV <- dV - (dV (V_high^T V_high^\top)) ? Let's keep it simpler:

        # We'll treat each row of dV, call it dV_i. Each row of V_high, call it V_high_j
        # We remove for each j: (dV_i dot V_high_j) * V_high_j / (V_high_j dot V_high_j) if needed
        # But if V_high is orthonormal (which it should be from SVD) then we can do:
        #    dV_i <- dV_i - sum_j( dV_i dot V_high_j ) * V_high_j
        # This is effectively: dV <- dV - (dV V_high^T) V_high  (since V_high is orthonormal, V_high V_high^T = I_k)

        # We'll assume V_high is orthonormal from SVD (it should be).
        # Then the projection of dV onto row-space of V_high is (dV * V_high^T) * V_high
        # but we must be careful with matrix dims. Let's do it directly:

        proj = (dV @ V_high.transpose(0,1)) @ V_high
        dV.sub_(proj)

    # S_low is just diagonal elements (vector). The "direction" for S_high is also a vector, but we typically freeze S_high.
    # If you want to project dS_low to be orthogonal to S_high, that might or might not make sense. Usually you freeze S_high entirely (no param).
    # For simplicity, do nothing to dS_low here (or if you want to zero it if you consider them in same dimension).
    # We'll do nothing as they do not share "direction space" the same way U/V do.


###################################################
# 3. T5 Model subclass with SVD
###################################################

class T5WithSVD(T5ForConditionalGeneration):
    """
    Subclass of T5ForConditionalGeneration that:
      - On init, decomposes each (or selected) weight matrix via SVD.
      - Freezes the top subspace.
      - Registers the bottom subspace as trainable parameters.
      - On forward, reconstructs the full weight.
      - Optionally, does gradient projection to keep updates orthogonal to the frozen subspace.
    """
    def __init__(self, config: T5Config, svd_config=None):
        """
        svd_config: dict specifying how many top singular vectors to freeze
                    for each layer or each matrix name, e.g.:
                    {
                       "encoder.block.0.layer.0.DenseReluDense.wi.weight": 16,
                       "shared.embedding": 0, # skip or no decomposition
                       ...
                    }
        You might parse a CSV to build this dictionary.
        """
        super().__init__(config)
        self.svd_config = svd_config if svd_config is not None else {}

        self.name_mapping = {}

        # A dictionary to store the SVD decomposition for each param we want to handle:
        #   self.svd_params[name] = {
        #       "U_high": buffer,
        #       "S_high": buffer,
        #       "V_high": buffer,
        #       "U_low": Parameter,
        #       "S_low": Parameter,
        #       "V_low": Parameter,
        #       "rank_high": k
        #   }
        self.svd_params = nn.ModuleDict()

        # We run through named_parameters, pick which ones to decompose
        self._initialize_svd_parameters()

    def _initialize_svd_parameters(self):
        for name, param in list(self.named_parameters()):
            # Decide if we want to decompose this param
            # We only do SVD on 2D weight matrices, skip biases or embeddings with dimension > 2
            if len(param.shape) == 2 and name in self.svd_config and self.svd_config[name] > 0:
                top_k = self.svd_config[name]
                print(f"[SVD Init] Decomposing {name} with top_k={top_k}")

                # Decompose
                svd_dict = decompose_weight_matrix(param.data, top_k=top_k)

                # 1) Create a sanitized name for buffers (no periods allowed)
                safe_name = name.replace(".", "_")
                self.name_mapping[name] = safe_name

                # Register buffers + parameters
                self.register_buffer(f"{safe_name}_U_high", svd_dict["U_high"])
                self.register_buffer(f"{safe_name}_S_high", svd_dict["S_high"])
                self.register_buffer(f"{safe_name}_V_high", svd_dict["V_high"])

                module_svd = nn.Module()
                module_svd.U_low = nn.Parameter(svd_dict["U_low"])
                module_svd.S_low = nn.Parameter(svd_dict["S_low"])
                module_svd.V_low = nn.Parameter(svd_dict["V_low"])
                module_svd.rank_high = svd_dict["rank_high"]
                module_svd.safe_name = safe_name

                self.svd_params[safe_name] = module_svd

                # Remove the original param from the model's param list
                # (we do that by setting 'requires_grad=False' or something similar).
                param.requires_grad = False
            else:
                # Not decomposing this param
                # Freeze everything else by turning off gradients!
                param.requires_grad = False

    def _reconstruct_weight(self, name):
        """
        Reconstruct the full weight matrix from the stored SVD decomposition (if it exists)
        or return the original param if not decomposed.
        """
        if name in self.name_mapping:

            # Retrieve the sanitized name we stored
            safe_name = self.name_mapping[name]

            # Reconstruct from high + low
            U_high = getattr(self, f"{safe_name}_U_high")
            S_high = getattr(self, f"{safe_name}_S_high")
            V_high = getattr(self, f"{safe_name}_V_high")

            module_svd = self.svd_params[safe_name]
            U_low = module_svd.U_low
            S_low = module_svd.S_low
            V_low = module_svd.V_low

            # Build dict for reconstruct
            svd_dict = {
                "U_high": U_high,
                "S_high": S_high,
                "V_high": V_high,
                "U_low": U_low,
                "S_low": S_low,
                "V_low": V_low
            }
            W = reconstruct_weight_matrix(svd_dict)
            return W
        else:
            # Not decomposed, just return the original param
            return dict(self.named_parameters())[name]

    def forward(self, *args, **kwargs):
        """
        Override forward to:
          1) Reconstruct the decomposed weights on-the-fly.
             - We have to inject them into the right modules.
          2) Then call the standard T5 forward.
        """
        # Step 1: inject reconstructed weights
        # We'll do a simple approach: for each decomposed param, find the actual module
        # that uses that param, and set `module.weight.data = reconstructed`.
        # This is somewhat hacky because T5 can rename parameters. We'll do a naive approach
        # that works if the param name has a direct path.

        # for name in self.svd_params:
        #     # name might look like: "encoder.block.0.layer.0.DenseReluDense.wi.weight"
        #     # We need to locate this module. We can do so with a utility function:
        #     module, param_name = self._get_module_by_name(name)
        #     if module is not None and hasattr(module, param_name):
        #         W = self._reconstruct_weight(name)
        #         getattr(module, param_name).data.copy_(W)

        for original_name, safe_name in self.name_mapping.items():
            module, param_name = self._get_module_by_name(original_name)
            if module is not None and hasattr(module, param_name):
                W = self._reconstruct_weight(original_name)
                if param_name in module._parameters:
                    module._parameters.pop(param_name)
                setattr(module, param_name, W)
                # getattr(module, param_name).copy_(W)

        # Step 2: call the original forward
        return super().forward(*args, **kwargs)

    def _get_module_by_name(self, name):
        """
        Utility to retrieve the module object and the final parameter name
        from the "dot" path. For example:
            name="encoder.block.0.layer.0.DenseReluDense.wi.weight"
        We find `self.encoder.block[0].layer[0].DenseReluDense.wi` as the module,
        and return (module, "weight").
        """
        parts = name.split(".")
        # The last part is typically "weight" or "bias"
        param_name = parts[-1]
        module_parts = parts[:-1]

        # Start from 'self'
        mod = self
        for p in module_parts:
            if hasattr(mod, p):
                mod = getattr(mod, p)
            elif p.isdigit():
                # if it's an index for a list or nn.ModuleList
                mod = mod[int(p)]
            else:
                # can't find the path
                return None, None
        return mod, param_name

    def project_gradients(self):
        """
        After loss.backward(), call this to project gradients in the "low" subspace
        so that no component in the high subspace is updated.
        """
        for name, module_svd in self.svd_params.items():
            # Build an svd_dict for projection
            svd_dict = {
                "U_high": getattr(self, f"{name}_U_high"),
                "S_high": getattr(self, f"{name}_S_high"),
                "V_high": getattr(self, f"{name}_V_high"),
                "U_low": module_svd.U_low,
                "S_low": module_svd.S_low,
                "V_low": module_svd.V_low,
            }
            project_gradient_to_orthogonal_space(svd_dict)


###################################################
# 4. Example usage: training & inference
###################################################

def load_svd_config_from_csv(csv_path):
    """
    Suppose your CSV has columns: [param_name, top_k].
    We'll parse that into a dict: { param_name: top_k, ... }
    """
    svd_config = {}
    with open(csv_path, "r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            param_name = row["param_name"]
            top_k = int(row["top_k"])
            svd_config[param_name] = top_k
    return svd_config

def train_svd_model():
    ############################################################################
    # 1. Load or define your T5 config, tokenizer, and possibly a pretrained model
    ############################################################################
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)

    # Example: let’s pretend we have a CSV that says which param_name -> top_k
    # Here we skip the real CSV reading for a quick example:
    # For demonstration, we'll pick one linear layer to freeze top 16 singular values
    # Adjust to your real CSV approach:
    svd_config = {
       "encoder.block.0.layer.1.DenseReluDense.wi.weight": 16,  # freeze top-16
       # ...
    }
    # Or do: svd_config = load_svd_config_from_csv("thresholds.csv")

    # Initialize our custom model
    model = T5WithSVD(config, svd_config=svd_config)

    for name, param in model.named_parameters():
      if param.requires_grad:
          print(f"Parameter {name} has gradients enabled")

    # Optionally load pretrained T5 weights
    pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
    model.load_state_dict(pretrained_model.state_dict(), strict=False)
    model.to("cuda")

    ############################################################################
    # 2. Build a dataset & dataloader
    ############################################################################
    # Dummy data
    dummy_data = [
        {"input": "translate English to German: Hello world", "target": "Hallo Welt"},
        {"input": "translate English to German: I love cats", "target": "Ich liebe Katzen"},
        *(
        {"input": f"Sample input {i}", "target": f"Sample target {i}"}
        for i in range(10)
        )
        # ...
    ]
    dataset = DummySeq2SeqDataset(dummy_data)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
                            collate_fn=lambda x: collate_fn(x, tokenizer, max_length=32))

    ############################################################################
    # 3. Prepare optimizer
    ############################################################################
    # Notice that the high subspace is not in model.parameters(), only low is.
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    # Store original parameter states for the entire model
    original_params = {name: p.clone().detach() for name, p in model.named_parameters()}

    # Also store the SVD components for our single decomposed matrix
    matrix_name = "encoder.block.0.layer.1.DenseReluDense.wi.weight"
    safe_name = matrix_name.replace(".", "_")

    orig_U_high = model.get_buffer(f"{safe_name}_U_high").clone().detach()
    orig_S_high = model.get_buffer(f"{safe_name}_S_high").clone().detach()
    orig_V_high = model.get_buffer(f"{safe_name}_V_high").clone().detach()

    orig_U_low = model.svd_params[safe_name].U_low.clone().detach()
    orig_S_low = model.svd_params[safe_name].S_low.clone().detach()
    orig_V_low = model.svd_params[safe_name].V_low.clone().detach()

    ############################################################################
    # 4. Training loop
    ############################################################################
    for epoch in range(10):  # small epoch count for demonstration
        for batch in dataloader:
            for k, v in batch.items():
                batch[k] = v.cuda()

            model.train()

            outputs = model(**batch, use_cache=False)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()

            # Project out the gradients w.r.t. the high subspace
            model.project_gradients()

            optimizer.step()

            print(f"Epoch {epoch} - Loss: {loss.item()}")

    # After training, check if other parameters changed
    print("\n=== Checking Parameter Changes (Entire Model) ===")
    for name, p in model.named_parameters():
        diff = (p - original_params[name]).abs().sum().item()
        if diff != 0:
            print(f"[CHANGED] {name} sum of absolute diff = {diff:.6f}")
        else:
            print(f"[NO CHANGE] {name} is unchanged")

    # Now check specifically the SVD decomposition for 'matrix_name'
    print(f"\n=== Checking High vs. Low Subspace for '{safe_name}' ===")
    new_U_high = model.get_buffer(f"{safe_name}_U_high")
    new_S_high = model.get_buffer(f"{safe_name}_S_high")
    new_V_high = model.get_buffer(f"{safe_name}_V_high")

    new_U_low = model.svd_params[safe_name].U_low
    new_S_low = model.svd_params[safe_name].S_low
    new_V_low = model.svd_params[safe_name].V_low

    # 1) High subspace changes (should remain zero or near-zero)
    diff_U_high = (new_U_high - orig_U_high).abs().sum().item()
    diff_S_high = (new_S_high - orig_S_high).abs().sum().item()
    diff_V_high = (new_V_high - orig_V_high).abs().sum().item()
    print(f"[High Subspace] Δ||U_high||={diff_U_high}, Δ||S_high||={diff_S_high}, Δ||V_high||={diff_V_high}")

    # 2) Low subspace changes (we expect these to update)
    diff_U_low = (new_U_low - orig_U_low).abs().sum().item()
    diff_S_low = (new_S_low - orig_S_low).abs().sum().item()
    diff_V_low = (new_V_low - orig_V_low).abs().sum().item()
    print(f"[Low Subspace] Δ||U_low||={diff_U_low}, Δ||S_low||={diff_S_low}, Δ||V_low||={diff_V_low}")

    # Print final low singular values to see how they've changed
    print(f"\nOriginal S_low[:10]: {orig_S_low[:10].cpu().numpy()}")
    print(f"New      S_low[:10]: {new_S_low[:10].detach().cpu().numpy()}")

    # 3) Check orthogonality: low vectors remain orthogonal to high subspace
    #    We'll measure the norm of the overlap:
    u_dot = ((new_U_high.transpose(0,1) @ new_U_low)**2).mean().item()
    v_dot = ((new_V_high @ new_V_low.transpose(0,1))**2).mean().item()

    print(f"[Orthogonality] ||U_high^T U_low||={u_dot:.6f}, ||V_high V_low^T||={v_dot:.6f}")
    print("Done!")

    u_dot = ((new_U_high.transpose(0,1) @ new_U_high)**2).mean().item()
    v_dot = ((new_V_high @ new_V_high.transpose(0,1))**2).mean().item()

    print(f"[Orthogonality] ||U_high^T U_high||={u_dot:.6f}, ||V_high V_high^T||={v_dot:.6f}")
    print("Done!")

    u_dot = ((new_U_low.transpose(0,1) @ new_U_low)**2).mean().item()
    v_dot = ((new_V_low @ new_V_low.transpose(0,1))**2).mean().item()

    print(f"[Orthogonality] ||U_low^T U_low||={u_dot:.6f}, ||V_low V_low^T||={v_dot:.6f}")
    print("Done!")

    # Save model
    torch.save(model.state_dict(), "t5_svd_finetuned.pt")

    # # --- New code: Compare saved model (model1) with a newly reloaded model (model2) ---
    # print("\n=== Checking model equality after reload ===")

    # # Get model1's state dictionary
    # model1_state = model.state_dict()

    # # Create a new instance (model2) and load the saved state
    # model2 = T5WithSVD(config, svd_config=svd_config)
    # model2.load_state_dict(torch.load("t5_svd_finetuned.pt"), strict=False)
    # model2.to("cuda")
    # model2.eval()

    # # Get model2's state dictionary
    # model2_state = model2.state_dict()

    # # Compare each key from model1 with model2
    # all_equal = True
    # for key in model1_state:
    #     if key in model2_state:
    #         # Use torch.allclose for tensor values (with a small tolerance)
    #         if torch.is_tensor(model1_state[key]):
    #             equal = torch.allclose(model1_state[key], model2_state[key], atol=1e-6)
    #         else:
    #             equal = model1_state[key] == model2_state[key]
    #         print(f"Parameter '{key}': equal? {equal}")
    #         if not equal:
    #             all_equal = False
    #     else:
    #         print(f"Parameter '{key}' is missing in model2!")
    #         all_equal = False

    # if all_equal:
    #     print("All parameters and buffers match between model1 and model2!")
    # else:
    #     print("Some parameters differ between model1 and model2!")

def inference_svd_model():
    """
    Illustrates how to do inference.
    """
    model_name = "t5-small"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)

    # Disable caching at the config level
    config.use_cache = False

    # Suppose we have the same svd_config used in training
    svd_config = {
       "encoder.block.0.layer.1.DenseReluDense.wi.weight": 16,
    }
    model = T5WithSVD(config, svd_config=svd_config)
    # Load your fine-tuned weights
    model.load_state_dict(torch.load("t5_svd_finetuned.pt"), strict=False)
    model.to("cuda")
    model.eval()

    # Let's do a generation example
    input_text = "translate English to German: I really like pizza"
    input_enc = tokenizer([input_text], return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(**input_enc, max_length=40)
    print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))


if __name__ == "__main__":
    # Example: train then inference
    train_svd_model()
    inference_svd_model()

[SVD Init] Decomposing encoder.block.0.layer.1.DenseReluDense.wi.weight with top_k=16
Parameter svd_params.encoder_block_0_layer_1_DenseReluDense_wi_weight.U_low has gradients enabled
Parameter svd_params.encoder_block_0_layer_1_DenseReluDense_wi_weight.S_low has gradients enabled
Parameter svd_params.encoder_block_0_layer_1_DenseReluDense_wi_weight.V_low has gradients enabled
Epoch 0 - Loss: 5.541877746582031
Epoch 0 - Loss: 5.464041233062744
Epoch 0 - Loss: 5.573670387268066
Epoch 0 - Loss: 4.080054759979248
Epoch 0 - Loss: 5.005031585693359
Epoch 0 - Loss: 6.652746677398682
Epoch 1 - Loss: 4.666175842285156
Epoch 1 - Loss: 5.0689287185668945
Epoch 1 - Loss: 9.569939613342285
Epoch 1 - Loss: 6.041984558105469
Epoch 1 - Loss: 4.660211563110352
Epoch 1 - Loss: 6.375530242919922
Epoch 2 - Loss: 4.61644172668457
Epoch 2 - Loss: 4.938105583190918
Epoch 2 - Loss: 6.711968898773193
Epoch 2 - Loss: 4.150201797485352
Epoch 2 - Loss: 4.112115859985352
Epoch 2 - Loss: 5.201481819152832
Epoch 3 

  model2.load_state_dict(torch.load("t5_svd_finetuned.pt"), strict=False)


Parameter 'encoder_block_0_layer_1_DenseReluDense_wi_weight_U_high': equal? True
Parameter 'encoder_block_0_layer_1_DenseReluDense_wi_weight_S_high': equal? True
Parameter 'encoder_block_0_layer_1_DenseReluDense_wi_weight_V_high': equal? True
Parameter 'shared.weight': equal? True
Parameter 'encoder.embed_tokens.weight': equal? True
Parameter 'encoder.block.0.layer.0.SelfAttention.q.weight': equal? True
Parameter 'encoder.block.0.layer.0.SelfAttention.k.weight': equal? True
Parameter 'encoder.block.0.layer.0.SelfAttention.v.weight': equal? True
Parameter 'encoder.block.0.layer.0.SelfAttention.o.weight': equal? True
Parameter 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight': equal? True
Parameter 'encoder.block.0.layer.0.layer_norm.weight': equal? True
Parameter 'encoder.block.0.layer.1.DenseReluDense.wo.weight': equal? True
Parameter 'encoder.block.0.layer.1.layer_norm.weight': equal? True
Parameter 'encoder.block.1.layer.0.SelfAttention.q.weight': equal? True
Par

  model.load_state_dict(torch.load("t5_svd_finetuned.pt"), strict=False)


Generated: Ich mag Pizza


In [None]:
import torch
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration

model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
config = T5Config.from_pretrained(model_name)

# Optionally load pretrained T5 weights
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to("cuda")

# --- Code to list all parameter names and shapes ---
for name, param in model.named_parameters():
    print(name, param.size())

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

shared.weight torch.Size([32128, 512])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([512, 512])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([512, 512])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([512, 512])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([512, 512])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 8])
encoder.block.0.layer.0.layer_norm.weight torch.Size([512])
encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([2048, 512])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([512, 2048])
encoder.block.0.layer.1.layer_norm.weight torch.Size([512])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([512, 512])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([512, 512])
encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([512, 512])
encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([512, 512])
encoder.block.1.layer.0.layer_norm.weight torc

In [1]:
import torch
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration

model_name = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
config = T5Config.from_pretrained(model_name)

# Optionally load pretrained T5 weights
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to("cuda")

# --- Code to list all parameter names and shapes ---
for name, param in model.named_parameters():
    print(name, param.size())

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


shared.weight torch.Size([32128, 1024])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 16])
encoder.block.0.layer.0.layer_norm.weight torch.Size([1024])
encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([4096, 1024])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([1024, 4096])
encoder.block.0.layer.1.layer_norm.weight torch.Size([1024])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.