In [1]:
! pip install datasets

Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Using cached datasets-3.2.0-py3-none-any.whl (480 kB)
Collecting dill<0.3.9,>=0.3.0
  Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Collecting fsspec[http]<=2024.9.0,>=2023.1.0
  Downloading fsspec-2024.9.0-py3-none-any.whl (179 kB)
     |████████████████████████████████| 179 kB 19.7 MB/s            
Collecting aiohttp
  Downloading aiohttp-3.11.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
     |████████████████████████████████| 1.6 MB 181.4 MB/s            
[?25hCollecting pandas
  Downloading pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
     |████████████████████████████████| 13.1 MB 184.4 MB/s            
[?25hCollecting pyarrow>=15.0.0
  Downloading pyarrow-19.0.0-cp39-cp39-manylinux_2_28_x86_64.whl (42.1 MB)
     |████████████████████████████████| 42.1 MB 79.2 MB/s            
Collecting xxhash
  Downloading xxhash-3.5

In [3]:
#######################################
# Configurable parameters
#######################################
FINE_TUNE_DATASET = "yelp"                  # Which dataset to finetune on
STARTING_CHECKPOINT = "t5_finetuned_agnews.pt"            # Either a local checkpoint or a model name on Hugging Face
OUTPUT_MODEL_NAME = "t5_finetuned_yelp.pt"  # Name of the saved model after training

import os
import json
import csv
import time
import torch
from torch.nn import DataParallel
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#######################################
# 1) Prompt construction
#######################################
def construct_prompt(sample, dataset_name):
    """
    Builds the instruction-based prompt for different classification tasks.
    """
    dataset_name = dataset_name.lower()

    if dataset_name == "ag_news":
        return (
            "What is the topic of the following paragraph? "
            "Choose from [World, Sports, Business, Sci/Tech]. "
            + sample["text"]
        )
    elif dataset_name == "amazon":
        return (
            "What is the sentiment of the following paragraph? "
            "Choose from [negative, positive]. "
            + sample["content"]
        )
    elif dataset_name == "yelp":
        return (
            "What is the sentiment of the following paragraph? "
            "Choose from [1 star, 2 star, 3 star, 4 star, 5 star]. "
            + sample["text"]
        )
    elif dataset_name == "dbpedia":
        return (
            "What is the topic of the following paragraph? "
            "Choose from [Company, Educational Institution, Artist, Athlete, Office Holder, "
            "Mean of Transportation, Building, Natural Place, Village, Animal, Plant, Album, "
            "Film, Written Work]. "
            + sample["content"]
        )
    elif dataset_name == "yahoo":
        return (
            "What is the topic of the following paragraph? "
            "Choose from [Society, Science, Health, Education, Computer, Sports, Business, "
            "Entertainment, Relationship, Politics]. "
            + sample["question_title"] + " " + sample["question_content"]
        )
    # Keep any other tasks if you’d like, or remove them if not needed:
    # elif dataset_name == "mnli":
    #    ...
    else:
        # Default fallback
        return "classify dataset: " + sample.get("text", sample.get("content", ""))


#######################################
# 2) Dataset construction
#######################################
class GenericClassificationDataset(Dataset):
    """
    A generic dataset class that uses the construct_prompt function and a label mapping.
    """
    def __init__(self, hf_dataset, split, tokenizer, label_mapping, dataset_name):
        # Shuffle and subset to 3600 examples for demonstration
        # self.dataset = hf_dataset[split].shuffle(seed=42).select(range(3600))
        self.dataset = hf_dataset[split]
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping
        self.dataset_name = dataset_name

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # Build the text prompt
        input_text = construct_prompt(sample, self.dataset_name)
        # Map numeric label to string label
        label_idx = sample.get("label", sample.get("topic"))
        label_str = self.label_mapping[label_idx]
        return input_text, label_str


def collate_fn_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(
        list(inputs),
        padding=True,
        truncation=True,
        max_length=max_source_length,
        return_tensors="pt",
    )
    target_encodings = tokenizer(
        list(targets),
        padding=True,
        truncation=True,
        max_length=max_target_length,
        return_tensors="pt",
    )
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings


#######################################
# 3) Training a normal T5 model
#######################################
def train_t5_model(
    fine_tune_dataset=FINE_TUNE_DATASET,
    starting_checkpoint=STARTING_CHECKPOINT,
    output_model_name=OUTPUT_MODEL_NAME,
    epochs=1,
    batch_size=64,
    lr=1e-3,
):
    """
    Loads a T5 model/ tokenizer, builds a train dataset & test dataset,
    then fine-tunes the model using standard cross-entropy.
    """
    # 1) Load the dataset from HF
    if fine_tune_dataset.lower() == "ag_news":
        hf_dataset = load_dataset("ag_news")
        label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    elif fine_tune_dataset.lower() == "amazon":
        hf_dataset = load_dataset("amazon_polarity")
        label_mapping = {0: "negative", 1: "positive"}
    elif fine_tune_dataset.lower() == "yelp":
        hf_dataset = load_dataset("yelp_review_full")
        label_mapping = {
            0: "1 star", 1: "2 star", 2: "3 star", 3: "4 star", 4: "5 star"
        }
    elif fine_tune_dataset.lower() == "dbpedia":
        hf_dataset = load_dataset("dbpedia_14")
        label_mapping = {
            0: "Company",
            1: "Educational Institution",
            2: "Artist",
            3: "Athlete",
            4: "Office Holder",
            5: "Mean of Transportation",
            6: "Building",
            7: "Natural Place",
            8: "Village",
            9: "Animal",
            10: "Plant",
            11: "Album",
            12: "Film",
            13: "Written Work",
        }
    elif fine_tune_dataset.lower() == "yahoo":
        hf_dataset = load_dataset("yahoo_answers_topics")
        label_mapping = {
            0: "Society",
            1: "Science",
            2: "Health",
            3: "Education",
            4: "Computer",
            5: "Sports",
            6: "Business",
            7: "Entertainment",
            8: "Relationship",
            9: "Politics",
        }
    else:
        raise ValueError(f"Unknown dataset: {fine_tune_dataset}")

    train_set = hf_dataset["train"]
    selected_train = []
    num_labels = len(label_mapping)
    for lab in range(num_labels):
        label_subset = train_set.filter(lambda x: x.get("label", x.get("topic")) == lab).shuffle(seed=42)
        selected_train.extend(label_subset.select(range(min(len(label_subset), 1000))))
    hf_dataset["train"] = selected_train

    # 2) Load the T5 tokenizer and config from a base model name (e.g. "t5-large")
    tokenizer = T5Tokenizer.from_pretrained("t5-large")
    config = T5Config.from_pretrained("t5-large")
    config.use_cache = False  # recommended to disable during training

    # 3) Create a T5ForConditionalGeneration model from base
    print(f"Loading T5 model from base 't5-large' ...")
    model = T5ForConditionalGeneration.from_pretrained("t5-large", config=config)
    model = model.to(device)

    model = DataParallel(model)

    # Now load your local ".pt" file into this base model
    if starting_checkpoint != "base":
        print(f"Loading state_dict from: {starting_checkpoint}")
        state_dict = torch.load(starting_checkpoint, map_location=device)
        model.load_state_dict(state_dict, strict=False)

    # 4) Build train / test datasets
    train_dataset = GenericClassificationDataset(
        hf_dataset, "train", tokenizer, label_mapping, fine_tune_dataset
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda b: collate_fn_fn(b, tokenizer),
    )

    # 5) Set up optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # 6) Training Loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            for k, v in batch.items():
                batch[k] = v.to(device)

            outputs = model(**batch)
            loss = outputs.loss
            loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} - Average Loss: {avg_loss:.4f}")

    # 7) Save final model
    print(f"Saving final model to: {output_model_name}")
    torch.save(model.module.state_dict(), output_model_name)
    return model, tokenizer, hf_dataset, label_mapping


#######################################
# 4) Evaluation
#######################################
def evaluate_model(model, tokenizer, hf_dataset, label_mapping, dataset_name):
    model.eval()

    test_set = hf_dataset["test"]
    selected_test = []
    num_labels = len(label_mapping)
    for lab in range(num_labels):
        label_subset = test_set.filter(lambda x: x.get("label", x.get("topic")) == lab).shuffle(seed=42)
        selected_test.extend(label_subset.select(range(min(len(label_subset), 500))))
    hf_dataset["test"] = selected_test

    # Rebuild the test_dataset and test_loader with the new selection
    test_dataset = GenericClassificationDataset(
        hf_dataset, "test", tokenizer, label_mapping, dataset_name
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=64,
        shuffle=False,
        collate_fn=lambda b: collate_fn_fn(b, tokenizer),
    )

    total, correct = 0, 0

    for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
        for k, v in batch.items():
            batch[k] = v.to(device)

        with torch.no_grad():
            generated_ids = model.module.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16,
            )

        # Decode predictions and labels
        predictions = [
            tokenizer.decode(g, skip_special_tokens=True).strip().lower()
            for g in generated_ids
        ]
        targets = [
            tokenizer.decode(lbl, skip_special_tokens=True).strip().lower()
            for lbl in batch["labels"]
        ]

        for pred, target in zip(predictions, targets):
            total += 1
            if pred == target:
                correct += 1

    accuracy = correct / total if total > 0 else 0
    print(f"Test Accuracy: {accuracy * 100:.2f}%")


#######################################
# 5) Main
#######################################
if __name__ == "__main__":
    model, tokenizer, hf_dataset, label_mapping = train_t5_model(
        fine_tune_dataset=FINE_TUNE_DATASET,
        starting_checkpoint=STARTING_CHECKPOINT,
        output_model_name=OUTPUT_MODEL_NAME,
        epochs=3,        # Modify if you want more or fewer epochs
        batch_size=64,    # Modify your batch size
        lr=1e-3,         # Learning rate
    )
    evaluate_model(
        model,
        tokenizer,
        hf_dataset,
        label_mapping,
        dataset_name=FINE_TUNE_DATASET
    )


Loading T5 model from base 't5-large' ...
Loading state_dict from: t5_finetuned_agnews.pt


Epoch 1/3: 100%|██████████| 79/79 [01:08<00:00,  1.15batch/s, eta=0.87s, loss=0.3199] 


Epoch 1/3 - Average Loss: 0.6930


Epoch 2/3: 100%|██████████| 79/79 [01:08<00:00,  1.16batch/s, eta=0.86s, loss=0.1866] 


Epoch 2/3 - Average Loss: 0.2821


Epoch 3/3: 100%|██████████| 79/79 [01:08<00:00,  1.15batch/s, eta=0.87s, loss=0.2374] 


Epoch 3/3 - Average Loss: 0.2229
Saving final model to: t5_finetuned_yelp.pt


Evaluating: 100%|██████████| 40/40 [00:28<00:00,  1.38batch/s]

Test Accuracy: 59.68%





In [None]:
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration

# device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

########################################################################
# 1) Construct prompt (same as your training code)
########################################################################
def construct_prompt(sample, dataset_name):
    dataset_name = dataset_name.lower()
    if dataset_name == "ag_news":
        return (
            "What is the topic of the following paragraph? "
            "Choose from [World, Sports, Business, Sci/Tech]. "
            + sample["text"]
        )
    elif dataset_name == "amazon":
        return (
            "What is the sentiment of the following paragraph? "
            "Choose from [negative, positive]. "
            + sample["content"]
        )
    elif dataset_name == "yelp":
        return (
            "What is the sentiment of the following paragraph? "
            "Choose from [1 star, 2 star, 3 star, 4 star, 5 star]. "
            + sample["text"]
        )
    elif dataset_name == "dbpedia":
        return (
            "What is the topic of the following paragraph? "
            "Choose from [Company, Educational Institution, Artist, Athlete, Office Holder, "
            "Mean of Transportation, Building, Natural Place, Village, Animal, Plant, Album, "
            "Film, Written Work]. "
            + sample["content"]
        )
    elif dataset_name == "yahoo":
        return (
            "What is the topic of the following paragraph? "
            "Choose from [Society, Science, Health, Education, Computer, Sports, Business, "
            "Entertainment, Relationship, Politics]. "
            + sample["question_title"] + " " + sample["question_content"]
        )
    else:
        # Default fallback
        return "classify dataset: " + sample.get("text", sample.get("content", ""))


########################################################################
# 2) Evaluation Dataset
########################################################################
class EvaluationDataset(Dataset):
    """
    A simple evaluation dataset for text classification.
    This takes a single-split Hugging Face Dataset object (e.g. 'test' only)
    and constructs a prompt + label for each sample.
    """
    def __init__(self, hf_dataset, tokenizer, label_mapping, dataset_name):
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping
        
        new_data = []
        num_labels = len(label_mapping)
        for lab in range(num_labels):
            label_subset = hf_dataset.filter(lambda x: x.get("label", x.get("topic")) == lab).shuffle(seed=42)
            label_subset = label_subset.select(range(min(len(label_subset), 500)))
            for row in label_subset:
                new_data.append(row)
       
        hf_dataset = new_data
        
        # Convert each row into (prompt_text, label_text)
        self.examples = []
        for sample in hf_dataset:
            prompt = construct_prompt(sample, self.dataset_name)
            # The dataset's "label" might not always be present; fallback to "topic" if needed
            label_idx = sample.get("label", sample.get("topic"))
            
            # Convert numeric label into string label
            if label_idx is not None and label_idx in self.label_mapping:
                label_str = self.label_mapping[label_idx]
            else:
                label_str = "N/A"
            
            self.examples.append((prompt, label_str))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]  # (prompt, label)


def collate_fn_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(
        list(inputs),
        padding=True,
        truncation=True,
        max_length=max_source_length,
        return_tensors="pt",
    )
    target_encodings = tokenizer(
        list(targets),
        padding=True,
        truncation=True,
        max_length=max_target_length,
        return_tensors="pt",
    )
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings


########################################################################
# 3) Evaluate Model and Return Accuracy
########################################################################
def evaluate_model_return(model, tokenizer, test_loader):
    """
    Evaluates the model on a test loader and returns the accuracy.
    Also prints up to 5 (prediction, target) examples.
    """
    model.eval()
    total, correct = 0, 0

    printed = 0
    max_print = 5
    
    for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
        for key, val in batch.items():
            batch[key] = val.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16,
            )
        # Decode predictions and targets
        predictions = [tokenizer.decode(g, skip_special_tokens=True).strip().lower()
                       for g in generated_ids]
        targets = [tokenizer.decode(label, skip_special_tokens=True).strip().lower()
                   for label in batch["labels"]]
        
        for input_text, pred, target in zip(batch["input_ids"], predictions, targets):
            total += 1
            input_text_decoded = tokenizer.decode(input_text, skip_special_tokens=True).strip()

            if printed < max_print:
                print(f"\nExample {printed+1}")
                print(f"  Input:      {input_text_decoded}")  # Print model input
                print(f"  Predicted:  {pred}")  # Print model output
                print(f"  Actual:     {target}")  # Print ground truth label
                printed += 1

            if pred == target:
                correct += 1
    
    return correct / total if total > 0 else 0


########################################################################
# 4) Evaluate on All Tasks
########################################################################
def evaluate_on_all_tasks(model_checkpoint, dataset_infos):
    """
    Loads the final model from `model_checkpoint` (a .pt file) and evaluates it on each 
    task defined in `dataset_infos`. Prints accuracy for each task and overall average.
    """
    # 1) Load a base T5 config & tokenizer
    tokenizer = T5Tokenizer.from_pretrained("t5-large")
    config = T5Config.from_pretrained("t5-large")
    config.use_cache = False

    # 2) Create a base T5 model and load your .pt file
    print(f"Loading base T5 model...")
    model = T5ForConditionalGeneration.from_pretrained("t5-large", config=config).to(device)
    print(f"Loading state_dict from: {model_checkpoint}")
    state_dict = torch.load(model_checkpoint, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    # 3) Evaluate on each dataset
    task_accuracies = {}
    for task_name, info in dataset_infos.items():
        print(f"\nEvaluating on {task_name} set:")

        # Load single-split dataset (e.g., "test")
        hf_dataset = load_dataset(
            info["hf_name"],
            split=info.get("split", "test"),
            **info.get("kwargs", {})
        )

        # Build evaluation dataset
        label_mapping = info["label_mapping"]
        eval_dataset = EvaluationDataset(
            hf_dataset,
            tokenizer,
            label_mapping,
            task_name
        )

        # Build dataloader
        test_loader = DataLoader(
            eval_dataset,
            batch_size=8,
            shuffle=False,
            collate_fn=lambda b: collate_fn_fn(b, tokenizer),
        )

        # Evaluate
        acc = evaluate_model_return(model, tokenizer, test_loader)
        task_accuracies[task_name] = acc
        print(f"{task_name} accuracy: {acc*100:.2f}%")

    # 4) Print overall average
    avg_acc = np.mean(list(task_accuracies.values()))
    print(f"\nAverage accuracy across all tasks: {avg_acc*100:.2f}%")
    return task_accuracies


########################################################################
# 5) Dictionary of tasks to evaluate
########################################################################
DATASET_INFOS = {
    "ag_news": {
        "hf_name": "ag_news",
        "split": "test",
        "label_mapping": {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    },
    "amazon": {
        "hf_name": "amazon_polarity",
        "split": "test",
        "label_mapping": {0: "negative", 1: "positive"}
    },
    "yelp": {
        "hf_name": "yelp_review_full",
        "split": "test",
        "label_mapping": {0: "1 star", 1: "2 star", 2: "3 star", 3: "4 star", 4: "5 star"}
    },
    "dbpedia": {
        "hf_name": "dbpedia_14",
        "split": "test",
        "label_mapping": {
            0: "Company", 1: "Educational Institution", 2: "Artist", 3: "Athlete",
            4: "Office Holder", 5: "Mean of Transportation", 6: "Building",
            7: "Natural Place", 8: "Village", 9: "Animal", 10: "Plant",
            11: "Album", 12: "Film", 13: "Written Work"
        }
    },
    "yahoo": {
        "hf_name": "yahoo_answers_topics",
        "split": "test",
        "label_mapping": {
            0: "Society", 1: "Science", 2: "Health", 3: "Education",
            4: "Computer", 5: "Sports", 6: "Business", 7: "Entertainment",
            8: "Relationship", 9: "Politics"
        }
    },
    # (Optional) You can add more tasks here...
}


########################################################################
# 6) Main
########################################################################
if __name__ == "__main__":
    final_model_path = "t5_finetuned_yelp.pt"  # Update to your final .pt file
    all_accuracies = evaluate_on_all_tasks(final_model_path, DATASET_INFOS)

Loading base T5 model...
Loading state_dict from: t5_finetuned_yelp.pt

Evaluating on ag_news set:


Evaluating:   1%|          | 2/250 [00:00<00:31,  7.82batch/s]


Example 1
  Input:      What is the topic of the following paragraph? Choose from [World, Sports, Business, Sci/Tech]. US Hostage Apparently Beheaded (CBS/AP) A video posted on an Islamic Web site Monday shows the apparent beheading of a man identified in the tape as American construction contractor Eugene Armstrong.
  Predicted:  1 star
  Actual:     world

Example 2
  Input:      What is the topic of the following paragraph? Choose from [World, Sports, Business, Sci/Tech]. French Govt., Muslims Appeal for Reporters' Release PARIS (Reuters) - France's government and leaders of its Muslim minority urged Iraqi militants Sunday to free two French journalists they were holding hostage in a bid to force Paris to revoke its ban on Muslim headscarves in schools.
  Predicted:  2 star
  Actual:     world

Example 3
  Input:      What is the topic of the following paragraph? Choose from [World, Sports, Business, Sci/Tech]. Fired Ecuador justices are barred from offices QUITO, Ecuador -- Ecuado

Evaluating: 100%|██████████| 250/250 [00:31<00:00,  8.00batch/s]


ag_news accuracy: 0.00%

Evaluating on amazon set:


Evaluating:   1%|          | 1/125 [00:00<00:26,  4.71batch/s]


Example 1
  Input:      What is the sentiment of the following paragraph? Choose from [negative, positive]. WHEEL LOCKS WERE CHEAP AND SHIP QUICKLY THE ONLY PROBLEM WAS THAT I BASED THE SIZE ON THE SIZE BAR AND IT SAID THAT IT FIT SO I BOUGHT THEM AND COME TO FIND OUT THAT THEY WERE THE WRONG SIZE THE SIZE I NEEDED WAS A 14 X 1.5.. SO MAKE SURE THAT YOU KNOW THE SIZE BEFORE YOU BUY DONT TRUST THAT SIZE BAR.
  Predicted:  1 star
  Actual:     negative

Example 2
  Input:      What is the sentiment of the following paragraph? Choose from [negative, positive]. I spilled liquid and needed like 2 rolls of this rubbish to pick it up. I felt so bad about this crappy product, I went and planted a tree afterwards.
  Predicted:  1 star
  Actual:     negative

Example 3
  Input:      What is the sentiment of the following paragraph? Choose from [negative, positive]. My son thought this looked cool thru the box so Santa brought it. Turns out it the only action you can have it to push a button. My

Evaluating: 100%|██████████| 125/125 [00:22<00:00,  5.68batch/s]


amazon accuracy: 0.00%

Evaluating on yelp set:


Evaluating:   0%|          | 1/313 [00:00<01:12,  4.32batch/s]


Example 1
  Input:      What is the sentiment of the following paragraph? Choose from [1 star, 2 star, 3 star, 4 star, 5 star]. It's a shame, big banking has destroyed this bank. Fees and more fees. Stay away. BBVA Compass has joined the ranks of banking zombies, you are a number and and not a person. nnThe staff at this branch is misinformed and poorly communicated changes. I don't think they even knew what BBVA was doing next month for fees. $10 lost a 10 year customer.nnMy advice, never pay banking fees. Google Community Bank with your zip and you will find local banks that offer free accounts, free atms, and much more.
  Predicted:  1 star
  Actual:     1 star

Example 2
  Input:      What is the sentiment of the following paragraph? Choose from [1 star, 2 star, 3 star, 4 star, 5 star]. Hmmmm.....our experience tonight was very disappointing, so much so that I'm not even motivated to write a witty or sarcastic review so I'll just state the facts for our party of four adults.nn- 40

Evaluating: 100%|██████████| 313/313 [01:30<00:00,  3.47batch/s]


yelp accuracy: 59.68%

Evaluating on dbpedia set:


Evaluating:   0%|          | 1/875 [00:00<02:18,  6.30batch/s]


Example 1
  Input:      What is the topic of the following paragraph? Choose from [Company, Educational Institution, Artist, Athlete, Office Holder, Mean of Transportation, Building, Natural Place, Village, Animal, Plant, Album, Film, Written Work]. Samsung R&D Institute Delhi (SRI - Delhi) earlier known as Samsung India Software Center was set up as a 11th Software R & D Center for Samsung Electronics located in Noida . It was established in October 2002. Samsung carries out its R&D activities in India through SRI-Delhi and SRI - Bangalore .
  Predicted:  5 star
  Actual:     company

Example 2
  Input:      What is the topic of the following paragraph? Choose from [Company, Educational Institution, Artist, Athlete, Office Holder, Mean of Transportation, Building, Natural Place, Village, Animal, Plant, Album, Film, Written Work]. Play-Asia.com is an online retailer for entertainment products from Asia. The website sells import games DVDs music CDs gadgets groceries books gaming conso

Evaluating: 100%|██████████| 875/875 [02:19<00:00,  6.29batch/s]


dbpedia accuracy: 0.00%

Evaluating on yahoo set:


Evaluating:   0%|          | 1/625 [00:00<01:27,  7.10batch/s]


Example 1
  Input:      What is the topic of the following paragraph? Choose from [Society, Science, Health, Education, Computer, Sports, Business, Entertainment, Relationship, Politics]. PLACE where you like to go AGAIN AND AGAIN and stay there LONG? this place must be such that it is possible for you to visit it once in a week so no tourist place please, give your number or name 1 home 2 garden 3 playground 4school 5 collage 6 your work place 7 river bank 8 beach 9 religious place 10 club 11 hobby center 12 in wild/jungle
  Predicted:  4 star
  Actual:     society

Example 2
  Input:      What is the topic of the following paragraph? Choose from [Society, Science, Health, Education, Computer, Sports, Business, Entertainment, Relationship, Politics]. In some religions, the creation of the world was accomplished through a Sacrifice. What's the meaning of this? What does it mean for God himself to offer a sacrifice? Who is He offering it to? What's the meaning behind these stories of t

Evaluating: 100%|██████████| 625/625 [01:32<00:00,  6.74batch/s]

yahoo accuracy: 0.00%

Average accuracy across all tasks: 11.94%





In [1]:
import json

def save_svd_config(config, file_path="svd_config.json"):
    with open(file_path, "w") as f:
        json.dump(config, f, indent=4)

def load_svd_config(file_path="svd_config.json"):
    with open(file_path, "r") as f:
        return json.load(f)

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm
import time
import json

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


###################################################
# 1. Define a PyTorch Dataset for DBpedia
###################################################
class DBpediaDataset(Dataset):
    """
    PyTorch dataset wrapper for the DBpedia dataset.
    Each example is converted to a text-to-text format.
    """
    def __init__(self, json_file, tokenizer):
        """
        hf_dataset: the Hugging Face dataset loaded via load_dataset("dbpedia_14")
        split: "train" or "test"
        tokenizer: a T5Tokenizer instance
        label_mapping: a dict mapping integer labels to string labels, e.g. {0:"Company", ...}
        """
        self.tokenizer = tokenizer

        # Load data from JSON file
        with open(json_file, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get the sample
        sample = self.dataset[idx]
        input_text = (
            "What is the topic of the following paragraph? Choose from [Company, Educational Institution, Artist, "
            "Athlete, Office Holder, Mean of Transportation, Building, Natural Place, Village, Animal, Plant, "
            "Album, Film, Written Work]. "
            + sample["sentence"]
        )
        target_text = sample["label"]  # Keep the label as a string
        return input_text, target_text


###################################################
# 2. Collate Function
###################################################
def collate_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    """
    Tokenize the batch of input and target texts.
    """
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer(list(targets), padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")

    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings


###################################################
# 3. Training and Evaluation Functions
###################################################
def train_finetune_t5():
    # Define paths to JSON train and test files
    train_json_path = "/workspace/O-LoRA/CL_Benchmark/TC/dbpedia/train.json"
    test_json_path = "/workspace/O-LoRA/CL_Benchmark/TC/dbpedia/test.json"

    # Define the label mapping for DBpedia (string labels to numerical labels)
    label_mapping = {
        "Company": 0, "Educational Institution": 1, "Artist": 2, "Athlete": 3,
        "Office Holder": 4, "Mean of Transportation": 5, "Building": 6,
        "Natural Place": 7, "Village": 8, "Animal": 9, "Plant": 10,
        "Album": 11, "Film": 12, "Written Work": 13
    }

    # Load pretrained T5 tokenizer and model (T5-large)
    model_name = "t5-large"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)

    # Create PyTorch datasets for train and test splits
    train_dataset = DBpediaDataset(train_json_path, tokenizer)
    test_dataset = DBpediaDataset(test_json_path, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=lambda batch: collate_fn(batch, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn(batch, tokenizer))

    # Prepare optimizer (full fine-tuning; all model parameters are updated)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    num_epochs = 3

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            # Move batch to device
            for key, val in batch.items():
                batch[key] = val.to(device)

            outputs = model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Estimate time remaining
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=loss.item(), eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    # Save the fine-tuned model
    torch.save(model.state_dict(), "t5_finetuned_dbpedia.pt")
    print("Model saved as 't5_finetuned_dbpedia.pt'.")

    return model, tokenizer, test_loader


def evaluate(model, tokenizer, test_loader):
    """
    Evaluate the fine-tuned model on the test set.
    """
    model.eval()
    correct = 0
    total = 0
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
            for key, val in batch.items():
                batch[key] = val.to(device)
            # Generate predictions
            generated_ids = model.generate(batch["input_ids"],
                                           attention_mask=batch["attention_mask"],
                                           max_length=16)
            predictions = [tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_ids]
            # Decode the ground truth labels
            targets = [tokenizer.decode(t, skip_special_tokens=True).strip() for t in batch["labels"]]

            for pred, target in zip(predictions, targets):
                if pred.lower() == target.lower():
                    correct += 1
                total += 1
                # Print only the first 10 examples
                if sample_count < 5:
                    print(f"Target: {target} | Prediction: {pred}")
                    sample_count += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Test Accuracy: {accuracy*100:.2f}%")
    return accuracy


###################################################
# 5. Main: Train, Check, and Evaluate
###################################################
if __name__ == "__main__":
    # Train and fine-tune T5 on DBpedia
    model1, tokenizer, test_loader = train_finetune_t5()

    # Evaluate the model
    evaluate(model1, tokenizer, test_loader)

Epoch 1/3: 100%|██████████| 1750/1750 [04:23<00:00,  6.65batch/s, eta=0.15s, loss=0.65]     


Epoch 1/3 - Average Loss: 0.4657


Epoch 2/3: 100%|██████████| 1750/1750 [04:23<00:00,  6.65batch/s, eta=0.15s, loss=0.519]  


Epoch 2/3 - Average Loss: 0.6570


Epoch 3/3: 100%|██████████| 1750/1750 [04:23<00:00,  6.65batch/s, eta=0.15s, loss=0.29]    


Epoch 3/3 - Average Loss: 0.4107
Model saved as 't5_finetuned_dbpedia.pt'.


Evaluating:   0%|          | 0/950 [00:00<?, ?batch/s]

Target: Office Holder | Prediction: Office Holder
Target: Album | Prediction: Album
Target: Plant | Prediction: Plant
Target: Building | Prediction: Building
Target: Natural Place | Prediction: Natural Place


Evaluating: 100%|██████████| 950/950 [01:35<00:00,  9.98batch/s]

Test Accuracy: 81.86%





In [3]:
# Configurable parameters:
SOURCE_SVD_DATASET = "dbpedia"       # Dataset to use when computing the adaptive SVD config.
FINE_TUNE_DATASET = "amazon"       # Fine-tuning dataset; options: "agnews", "amazon", "yelp", "dbpedia", "yahoo"
STARTING_CHECKPOINT = "t5_finetuned_dbpedia.pt"  # Path to the checkpoint you want to start from.
OUTPUT_MODEL_NAME = "t5_svd_amazon.pt"         # Name for the saved model after fine-tuning.

In [5]:
import os
import json
import csv
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
import numpy as np

torch.autograd.set_detect_anomaly(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def construct_prompt(sample, dataset_name):
    dataset_name = dataset_name.lower()
    if dataset_name == "agnews":
        # return "classify ag_news dataset: " + sample["text"]
        return (
            "What is the topic of the following paragraph? "
            "Choose from [World, Sports, Business, Sci/Tech]. "
            + sample["text"]
        )
    elif dataset_name == "amazon":
        # return "classify amazon dataset: " + sample["content"]
        return (
            "What is the sentiment of the following paragraph? "
            "Choose from [negative, positive]. "
            + sample["text"]
        )
    elif dataset_name == "yelp":
        # return "classify yelp dataset: " + sample["text"]
        return (
            "What is the sentiment of the following paragraph? "
            "Choose from [1 star, 2 star, 3 star, 4 star, 5 star]. "
            + sample["text"]
        )
    elif dataset_name == "dbpedia":
        # return "classify dbpedia dataset: " + sample["content"]
        return (
            "What is the topic of the following paragraph? "
            "Choose from [Company, Educational Institution, Artist, Athlete, Office Holder, "
            "Mean of Transportation, Building, Natural Place, Village, Animal, Plant, Album, "
            "Film, Written Work]. "
            + sample["text"]
        )
    elif dataset_name == "yahoo":
        # return "classify yahoo dataset: " + sample["question_title"] + " " + sample["question_content"]
        return (
            "What is the topic of the following paragraph? "
            "Choose from [Society, Science, Health, Education, Computer, Sports, Business, "
            "Entertainment, Relationship, Politics]. "
            + sample["question_title"] + " " + sample["question_content"]
        )
    elif dataset_name == "mnli":
        return "classify mnli dataset: premise: " + sample["premise"] + " hypothesis: " + sample["hypothesis"]
    elif dataset_name == "qqp":
        return "classify qqp dataset: question1: " + sample["question1"] + " question2: " + sample["question2"]
    elif dataset_name == "rte":
        return "classify rte dataset: sentence1: " + sample["sentence1"] + " sentence2: " + sample["sentence2"]
    elif dataset_name == "sst2":
        return "classify sst2 dataset: sentence: " + sample["sentence"]
    elif dataset_name == "wic":
        return "classify wic dataset: word: " + sample["word"] + " sentence1: " + sample["sentence1"] + " sentence2: " + sample["sentence2"]
    elif dataset_name == "cb":
        return "classify cb dataset: premise: " + sample["premise"] + " hypothesis: " + sample["hypothesis"]
    elif dataset_name == "copa":
        return "classify copa dataset: premise: " + sample["premise"] + " choice1: " + sample["choice1"] + " choice2: " + sample["choice2"]
    # elif dataset_name == "multirc":
    #     return "classify multirc dataset: question: " + sample["question"] + " passage: " + sample["passage"]
    elif dataset_name == "boolq":
        return "classify boolq dataset: question: " + sample["question"] + " passage: " + sample["passage"]
    elif dataset_name == "imdb":
        return "classify imdb dataset: " + sample["text"]
    else:
        return "classify dataset: " + sample.get("text", sample.get("content", ""))

###################################################
# 1. Helper Functions for SVD and Parameter Management
###################################################

def decompose_weight_matrix(weight: torch.Tensor, top_k: int):
    """
    Perform SVD on a 2D weight matrix and split into:
      - top_k singular vectors (treated as frozen/buffers)
      - the rest (treated as trainable)
    Returns a dictionary containing:
      {
        "U_high": ...  # buffer
        "S_high": ...  # buffer
        "V_high": ...  # buffer
        "U_low": ...   # parameter
        "S_low": ...   # parameter
        "V_low": ...   # parameter
        "rank_high": top_k
      }
    """
    device_local = weight.device
    W = weight.to(torch.float32)  # ensure float32 for SVD
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)
    # Ensure we don’t ask for more than available
    k = min(top_k, S.shape[0])

    # High subspace (frozen)
    U_high = U[:, :k].detach().to(device_local)
    S_high = S[:k].detach().to(device_local)
    V_high = Vt[:k, :].detach().to(device_local)

    # Low subspace (trainable)
    U_low = U[:, k:].detach().to(device_local)
    S_low = S[k:].detach().to(device_local)
    V_low = Vt[k:, :].detach().to(device_local)

    return {
        "U_high": U_high,
        "S_high": S_high,
        "V_high": V_high,
        "U_low": nn.Parameter(U_low),
        "S_low": nn.Parameter(S_low),
        "V_low": nn.Parameter(V_low),
        "rank_high": k
    }


def reconstruct_weight_matrix(svd_dict):
    """
    Reconstruct the full weight matrix:
       W = U_high * diag(S_high) * V_high^T + U_low * diag(S_low) * V_low^T
    """
    U_high = svd_dict["U_high"]
    S_high = svd_dict["S_high"]
    V_high = svd_dict["V_high"]
    U_low = svd_dict["U_low"]
    S_low = svd_dict["S_low"]
    V_low = svd_dict["V_low"]

    if U_high.shape[1] > 0 and S_high.shape[0] > 0:
        high_part = torch.mm(U_high * S_high.unsqueeze(0), V_high)
    else:
        high_part = torch.zeros(U_low.size(0), V_low.size(1), device=U_high.device)

    if U_low.shape[1] > 0 and S_low.shape[0] > 0:
        US_low = U_low * S_low.unsqueeze(0)
        low_part = torch.mm(US_low, V_low)
    else:
        low_part = torch.zeros(U_high.size(0), V_high.size(1), device=U_low.device)

    return high_part + low_part


def check_reconstruction_error(weight, svd_dict, atol=1e-5):
    # Move the weight to the same device as the U_high buffer
    target_device = svd_dict["U_high"].device
    weight = weight.to(target_device)
    W_recon = reconstruct_weight_matrix(svd_dict)
    # Ensure reconstruction is also on the target device
    W_recon = W_recon.to(target_device)
    error = torch.norm(weight - W_recon) / torch.norm(weight)
    if error > atol:
        print(f"Warning: Reconstruction error {error:.2e} exceeds tolerance {atol}")
    return error


def project_gradient_to_orthogonal_space(svd_dict):
    """
    Remove from the gradients of the low subspace any component that lies
    in the high subspace.
    """
    if (svd_dict["U_low"].grad is None and
        svd_dict["S_low"].grad is None and
        svd_dict["V_low"].grad is None):
        return

    U_high = svd_dict["U_high"]
    V_high = svd_dict["V_high"]

    if svd_dict["U_low"].grad is not None:
        dU = svd_dict["U_low"].grad
        proj = U_high @ (U_high.transpose(0,1) @ dU)
        dU.sub_(proj)

    if svd_dict["V_low"].grad is not None:
        dV = svd_dict["V_low"].grad
        proj = (dV @ V_high.transpose(0,1)) @ V_high
        dV.sub_(proj)
    # We leave S_low unchanged


def compute_effective_rank(matrix):
    """
    Compute the effective rank of a matrix based on the definition provided.
    """
    _, S, _ = torch.linalg.svd(matrix, full_matrices=False)
    singular_values = S.cpu().numpy()

    # Compute the singular value distribution (p_k)
    l1_norm = np.sum(np.abs(singular_values))
    p_k = singular_values / l1_norm

    # Compute the Shannon entropy
    H = -np.sum(p_k * np.log(p_k + 1e-10))  # Add a small constant to avoid log(0)

    # Compute the effective rank
    effective_rank = np.exp(H)

    return effective_rank


###################################################
# 2. T5 Model Subclass with SVD (Only for Selected Parameters)
###################################################

class T5WithSVD(T5ForConditionalGeneration):
    """
    Subclass that, on initialization, decomposes selected weight matrices via SVD.
    Only parameters specified in the svd_config are decomposed.
    For each such 2D weight, we freeze the top singular vectors (50% by default)
    and register the lower half (trainable) as parameters.

    Additionally, we pre-compute the module mapping for faster weight injection.
    """
    def __init__(self, config: T5Config, svd_config=None, initialize_svd=True):
        super().__init__(config)
        # svd_config is a dict mapping full parameter names to top_k values.
        self.svd_config = svd_config if svd_config is not None else {}
        self.name_mapping = {}         # maps original name -> safe name
        self.svd_original_mapping = {} # maps safe name -> original name
        self.svd_params = nn.ModuleDict()
        self.svd_module_mapping = {}   # maps safe name -> (module, attribute_name)
        if initialize_svd:
          self._initialize_svd_parameters()

    def reinitialize_svd(self):
        """
        Reinitialize the SVD decomposition on the current (loaded) weights.
        Before reinitialization, store a copy of the original weights for each target parameter,
        then after reinitialization, check and print the reconstruction error.
        """
        # Save original weights for each parameter to be decomposed.
        self._original_weights = {}
        for orig_name in self.svd_config.keys():
            # Retrieve from the model's state_dict; ensure it is on the correct device.
            self._original_weights[orig_name] = self.state_dict()[orig_name].clone().to(device)

        # Clear previous SVD mappings.
        self.name_mapping = {}
        self.svd_original_mapping = {}
        self.svd_params = nn.ModuleDict()
        self.svd_module_mapping = {}
        # Reinitialize the SVD decomposition using the current weights.
        self._initialize_svd_parameters()

        # Now, for each decomposed parameter, compute and print the reconstruction error.
        for orig_name, safe_name in self.name_mapping.items():
            orig_weight = self._original_weights[orig_name]
            svd_dict = {
                "U_high": getattr(self, f"{safe_name}_U_high"),
                "S_high": getattr(self, f"{safe_name}_S_high"),
                "V_high": getattr(self, f"{safe_name}_V_high"),
                "U_low": self.svd_params[safe_name].U_low,
                "S_low": self.svd_params[safe_name].S_low,
                "V_low": self.svd_params[safe_name].V_low
            }
            error = check_reconstruction_error(orig_weight, svd_dict)
            print(f"Reconstruction error for {orig_name}: {error:.2e}")

    def _initialize_svd_parameters(self):
        # Iterate over all parameters
        for name, param in list(self.named_parameters()):
            if len(param.shape) == 2 and name in self.svd_config and self.svd_config[name] > 0:
                top_k = self.svd_config[name]
                print(f"[SVD Init] Decomposing {name} with top_k={top_k}")
                svd_dict = decompose_weight_matrix(param.data, top_k=top_k)
                safe_name = name.replace(".", "_")
                self.name_mapping[name] = safe_name
                self.svd_original_mapping[safe_name] = name

                # Compute the residual: the difference between the original weight and its SVD reconstruction.
                # residual = (param.data - reconstruct_weight_matrix(svd_dict)).detach()
                # Register the residual as a buffer (no gradients).
                # self.register_buffer(f"{safe_name}_residual", residual)

                # Register buffers for the high subspace
                self.register_buffer(f"{safe_name}_U_high", svd_dict["U_high"])
                self.register_buffer(f"{safe_name}_S_high", svd_dict["S_high"])
                self.register_buffer(f"{safe_name}_V_high", svd_dict["V_high"])

                # Create a module to hold the low subspace trainable parameters
                module_svd = nn.Module()
                module_svd.U_low = nn.Parameter(svd_dict["U_low"])
                module_svd.S_low = nn.Parameter(svd_dict["S_low"])
                module_svd.V_low = nn.Parameter(svd_dict["V_low"])
                module_svd.rank_high = svd_dict["rank_high"]
                module_svd.safe_name = safe_name
                self.svd_params[safe_name] = module_svd

                # Freeze the original parameter
                param.requires_grad = False

                # Pre-compute and store the module and attribute name for quick access
                mod, attr = self._get_module_by_name(name)
                if mod is not None:
                    self.svd_module_mapping[safe_name] = (mod, attr)
            # For parameters not in svd_config, leave them trainable (do nothing)

    def _reconstruct_weight(self, original_name):
        safe_name = self.name_mapping[original_name]
        U_high = getattr(self, f"{safe_name}_U_high")
        S_high = getattr(self, f"{safe_name}_S_high")
        V_high = getattr(self, f"{safe_name}_V_high")
        module_svd = self.svd_params[safe_name]
        U_low = module_svd.U_low
        S_low = module_svd.S_low
        V_low = module_svd.V_low
        svd_dict = {"U_high": U_high, "S_high": S_high, "V_high": V_high,
                    "U_low": U_low, "S_low": S_low, "V_low": V_low}
        W = reconstruct_weight_matrix(svd_dict)

        # Retrieve the residual that was stored during initialization.
        # residual = getattr(self, f"{safe_name}_residual").detach()

        # return W + residual

        return W

    def forward(self, *args, **kwargs):
        # Instead of recomputing the module mapping each time,
        # iterate over the precomputed svd_module_mapping.
        for safe_name, (module, attr) in self.svd_module_mapping.items():
            original_name = self.svd_original_mapping[safe_name]
            W = self._reconstruct_weight(original_name)
            # if attr in module._parameters:
            #     print(type(module._parameters))
            #     print(module._parameters)
            #     print(attr)
            #     module._parameters.pop(attr)
            # setattr(module, attr, W)
            # print(module._parameters)
            with torch.no_grad():
                getattr(module, attr).data.copy_(W)
        return super().forward(*args, **kwargs)

    def _get_module_by_name(self, name):
        """
        Given a full parameter name (e.g. "encoder.block.0.layer.0.SelfAttention.q.weight"),
        return (module, attribute_name) where module.attribute_name is that parameter.
        """
        parts = name.split(".")
        attr = parts[-1]
        mod = self
        for p in parts[:-1]:
            if hasattr(mod, p):
                mod = getattr(mod, p)
            elif p.isdigit():
                mod = mod[int(p)]
            else:
                return None, None
        return mod, attr

    def project_gradients(self):
        for safe_name, module_svd in self.svd_params.items():
            svd_dict = {
                "U_high": getattr(self, f"{safe_name}_U_high"),
                "S_high": getattr(self, f"{safe_name}_S_high"),
                "V_high": getattr(self, f"{safe_name}_V_high"),
                "U_low": module_svd.U_low,
                "S_low": module_svd.S_low,
                "V_low": module_svd.V_low,
            }
            project_gradient_to_orthogonal_space(svd_dict)

###################################################
# 3. Utility: Auto-generate SVD Config for Target Parameters
###################################################
def auto_generate_target_svd_config(model):
    """
    Given a model, generate an SVD configuration dictionary only for parameters that contain one of the
    following substrings:
      - SelfAttention.q.weight
      - SelfAttention.k.weight
      - SelfAttention.v.weight
      - SelfAttention.o.weight
      - DenseReluDense.wi.weight
      - DenseReluDense.wo.weight
    For each such 2D parameter, set:
         top_k = floor(min(dim0, dim1) / 2)
    """
    target_patterns = [
        "SelfAttention.q.weight",
        "SelfAttention.k.weight",
        "SelfAttention.v.weight",
        "SelfAttention.o.weight",
        "DenseReluDense.wi.weight",
        "DenseReluDense.wo.weight"
    ]
    config = {}
    for name, param in model.named_parameters():
        if any(pat in name for pat in target_patterns) and len(param.shape) == 2:
            # effective_rank = compute_effective_rank(param.data)
            # top_k = int(np.floor(effective_rank))
            # full_rank = min(param.shape)
            # if top_k > full_rank:
            #     top_k = full_rank
            # config[name] = top_k
            top_k = int(np.floor(max(param.shape)*0.2))
            full_rank = min(param.shape)
            if top_k > full_rank:
                top_k = full_rank
            config[name] = top_k
    # save_svd_config(config)
    return config

# def auto_generate_target_svd_config(model, tokenizer, n_samples=128, batch_size=8, num_batches=5, source_dataset=SOURCE_SVD_DATASET):
#     """
#     For each target parameter (matching target_patterns), compute the adaptive retention ratio based on
#     the importance I(W) measured using actual inputs from the AGNews test set.

#     For each target parameter W (shape: (d, m), let d = min(W.shape)).
#     For each such parameter:
#        - Run num_batches of AGNews test data through the model with hooks to capture the input X for
#          the module corresponding to W.
#        - Concatenate the captured X from all batches to form a matrix X of shape (m, total_samples).
#        - Compute I(W) = average cosine similarity between columns of X and Y = W @ X.
#     Then normalize importance by the mean and set:
#        CR(W) = 1 + (I(W)/mean(I(W)))*((d/2) - 1)
#        k = round(CR(W) * d / 2)
#     Clamp k between 1 and d.
#     Return a dictionary mapping parameter names to top_k.
#     """
#     target_patterns = [
#         "SelfAttention.q.weight",
#         "SelfAttention.k.weight",
#         "SelfAttention.v.weight",
#         "SelfAttention.o.weight",
#         "DenseReluDense.wi.weight",
#         "DenseReluDense.wo.weight"
#     ]
#     # Dictionary to store importance for each target parameter.
#     importance_dict = {}
#     # Dictionary to store captured inputs for each target parameter.
#     captured_inputs = {name: [] for name, param in model.named_parameters()
#                          if any(pat in name for pat in target_patterns) and len(param.shape)==2}

#     # Create hooks to capture inputs for each target module.
#     hooks = {}
#     def get_hook(name):
#         def hook(module, input, output):
#             # input[0] might have shape (batch_size, seq_length, in_features)
#             X = input[0]
#             # Flatten the batch and sequence dimensions into one:
#             X = X.reshape(-1, X.shape[-1])  # shape: (batch_size * seq_length, in_features)
#             # Transpose so that columns represent individual samples:
#             captured_inputs[name].append(X.transpose(0, 1).detach())
#         return hook

#     # For each target parameter, register a hook on its parent module.
#     for name, param in model.named_parameters():
#         if any(pat in name for pat in target_patterns) and len(param.shape)==2:
#             mod, attr = model._get_module_by_name(name)
#             if mod is not None:
#                 hooks[name] = mod.register_forward_hook(get_hook(name))

#     # Now run a few batches of test data from the dataset.
#     from datasets import load_dataset

#     # Load the chosen source dataset and build inputs appropriately.
#     if source_dataset.lower() == "ag_news":
#         dataset = load_dataset("ag_news", split="test")
#         inputs = [f"classify ag_news dataset: " + sample["text"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "amazon":
#         dataset = load_dataset("amazon_polarity", split="test")
#         inputs = [f"classify amazon dataset: " + sample["content"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "yelp":
#         dataset = load_dataset("yelp_review_full", split="test")
#         inputs = [f"classify yelp dataset: " + sample["text"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "dbpedia":
#         dataset = load_dataset("dbpedia_14", split="test")
#         inputs = [f"classify dbpedia dataset: " + sample["content"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "yahoo":
#         dataset = load_dataset("yahoo_answers_topics", split="test")
#         inputs = [f"classify yahoo dataset: " + sample["question_title"] + " " + sample["question_content"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "mnli":
#         dataset = load_dataset("glue", "mnli", split="validation_matched")
#         inputs = [f"classify mnli dataset: premise: " + sample["premise"] + " hypothesis: " + sample["hypothesis"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "qqp":
#         dataset = load_dataset("glue", "qqp", split="validation")
#         inputs = [f"classify qqp dataset: question1: " + sample["question1"] + " question2: " + sample["question2"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "rte":
#         dataset = load_dataset("glue", "rte", split="test")
#         inputs = [f"classify rte dataset: sentence1: " + sample["sentence1"] + " sentence2: " + sample["sentence2"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "sst2":
#         dataset = load_dataset("glue", "sst2", split="test")
#         inputs = [f"classify sst2 dataset: " + sample["sentence"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "wic":
#         dataset = load_dataset("super_glue", "wic", split="test")
#         inputs = [f"classify wic dataset: word: " + sample["word"] + " sentence1: " + sample["sentence1"] + " sentence2: " + sample["sentence2"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "cb":
#         dataset = load_dataset("super_glue", "cb", split="test")
#         inputs = [f"classify cb dataset: premise: " + sample["premise"] + " hypothesis: " + sample["hypothesis"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "copa":
#         dataset = load_dataset("super_glue", "copa", split="test")
#         inputs = [f"classify copa dataset: premise: " + sample["premise"] + " choice1: " + sample["choice1"] + " choice2: " + sample["choice2"] for sample in dataset.select(range(n_samples))]
#     # elif source_dataset.lower() == "multirc":
#     #     dataset = load_dataset("super_glue", "multirc", split="test")
#     #     inputs = [f"classify multirc dataset: question: " + sample["question"] + " passage: " + sample["passage"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "boolq":
#         dataset = load_dataset("super_glue", "boolq", split="test")
#         inputs = [f"classify boolq dataset: question: " + sample["question"] + " passage: " + sample["passage"] for sample in dataset.select(range(n_samples))]
#     elif source_dataset.lower() == "imdb":
#         dataset = load_dataset("imdb", split="test")
#         inputs = [f"classify imdb dataset: " + sample["text"] for sample in dataset.select(range(n_samples))]
#     else:
#         raise ValueError(f"Unknown source dataset: {source_dataset}")

#     encodings = tokenizer(inputs, padding=True, truncation=True, max_length=512, return_tensors="pt")
#     # Wrap the BatchEncoding in a custom Dataset
#     class BatchEncodingDataset(Dataset):
#         def __init__(self, encodings):
#             self.encodings = encodings
#         def __len__(self):
#             return self.encodings["input_ids"].shape[0]
#         def __getitem__(self, idx):
#             return {key: val[idx] for key, val in self.encodings.items()}

#     dataset = BatchEncodingDataset(encodings)
#     loader = DataLoader(dataset, batch_size=batch_size)
#     # agnews_loader = DataLoader(encodings, batch_size=batch_size)

#     model = model.to(device)
#     model.eval()
#     batches = 0
#     with torch.no_grad():
#         for batch in loader:
#             batch = {k: v.to(device) for k, v in batch.items()}
#             # _ = model(**batch)  # forward pass to trigger hooks
#             # batches += 1
#             # if batches >= num_batches:
#             #     break

#             batch_size = batch["input_ids"].shape[0]
#             # Create a dummy decoder input using the model's decoder_start_token_id.
#             # T5 usually uses 0 or the value from config.decoder_start_token_id.
#             dummy_decoder_input_ids = torch.full(
#                 (batch_size, 1),
#                 model.config.decoder_start_token_id,
#                 device=device,
#                 dtype=batch["input_ids"].dtype
#             )
#             # Forward pass with both encoder and decoder inputs.
#             _ = model(
#                 input_ids=batch["input_ids"],
#                 attention_mask=batch["attention_mask"],
#                 decoder_input_ids=dummy_decoder_input_ids
#             )

#     # Remove hooks.
#     for h in hooks.values():
#         h.remove()

#     # Now compute importance for each target parameter.
#     for name in captured_inputs.keys():
#         # Concatenate captured inputs along last dimension.
#         X = torch.cat(captured_inputs[name], dim=1).to(device)  # shape: (in_features, total_samples)
#         W = model.state_dict()[name].to(device)
#         Y = torch.mm(W, X)

#         # Determine m = min(W.shape) and slice both X and Y to the first m rows.
#         m = min(W.shape)
#         X_mod = X[:m, :]
#         Y_mod = Y[:m, :]

#         X_norm = X_mod / (torch.norm(X_mod, dim=0, keepdim=True) + 1e-8)
#         Y_norm = Y_mod / (torch.norm(Y_mod, dim=0, keepdim=True) + 1e-8)
#         cosine_sim = torch.sum(X_norm * Y_norm, dim=0)
#         I_W = torch.mean(cosine_sim).item()
#         importance_dict[name] = I_W

#     mean_importance = np.mean(list(importance_dict.values()))
#     config = {}
#     for name, param in model.named_parameters():
#         if name in importance_dict:
#             d = min(param.shape)
#             I_W = importance_dict[name]
#             I_n = I_W / (mean_importance + 1e-8)
#             mrr = d / 2.0 # 1.0
#             trr = d # d / 2.0
#             CR = mrr + I_n * (trr - mrr)
#             # As explained: full params of W is 2*d^2 (for square W) and retained params is 2*d*k,
#             # so we set k/d = CR  => k = CR * d.
#             # k = int(round(CR * d))
#             k = int(round(CR))
#             k = max(1, min(k, d))
#             config[name] = k
#     save_svd_config(config)
#     return config

###################################################
# 4. Dataset Construction
###################################################
class GenericClassificationDataset(Dataset):
    """
    A generic dataset that works for multiple classification datasets.
    Expects the HF dataset to have either "text" or "content" as the input field.
    The prompt is constructed as "classify {dataset_name} dataset: <input>"
    """
    def __init__(self, json_file, tokenizer, label_mapping, dataset_name):
        # self.dataset = hf_dataset[split].shuffle(seed=42).select(range(3600))

        self.tokenizer = tokenizer
        self.label_mapping = label_mapping
        self.dataset_name = dataset_name

        # Load data from JSON file
        with open(json_file, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        
        # Ensure full dataset is used
        self.dataset = [
            {"text": sample["sentence"], "label": str(sample["label"])}  # Store labels as strings
            for sample in self.dataset
        ]

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        input_text = construct_prompt(sample, self.dataset_name)
        return input_text, sample["label"]

def collate_fn_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer([str(t) for t in targets], padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings

###################################################

# 5. Training and Saving the SVD Model on Amazon Reviews
###################################################
def train_svd_model(fine_tune_dataset=FINE_TUNE_DATASET, starting_checkpoint=STARTING_CHECKPOINT, output_model_name=OUTPUT_MODEL_NAME):

    if fine_tune_dataset in ["agnews", "dbpedia", "yahoo"]:
        train_json_path = f"/workspace/O-LoRA/CL_Benchmark/TC/{fine_tune_dataset}/train.json"
        test_json_path = f"/workspace/O-LoRA/CL_Benchmark/TC/{fine_tune_dataset}/test.json"

    elif fine_tune_dataset in ["amazon", "yelp"]:
        train_json_path = f"/workspace/O-LoRA/CL_Benchmark/SC/{fine_tune_dataset}/train.json"
        test_json_path = f"/workspace/O-LoRA/CL_Benchmark/SC/{fine_tune_dataset}/test.json"

    from datasets import load_dataset
    if fine_tune_dataset.lower() == "agnews":
        label_mapping = {"World": 0, "Sports": 1, "Business": 2, "Science or Technology": 3}
    elif fine_tune_dataset.lower() == "amazon":
        label_mapping = {"negative": 0, "positive": 1}
    elif fine_tune_dataset.lower() == "yelp":
        label_mapping = {"1 star": 0, "2 star": 1, "3 star": 2, "4 star": 3, "5 star": 4}
    elif fine_tune_dataset.lower() == "dbpedia":
        label_mapping = {"Company": 0, "Educational Institution": 1, "Artist": 2, "Athlete": 3,
            "Office Holder": 4, "Mean of Transportation": 5, "Building": 6, "Natural Place": 7,
            "Village": 8, "Animal": 9, "Plant": 10, "Album": 11, "Film": 12, "Written Work": 13}
    elif fine_tune_dataset.lower() == "yahoo":
        label_mapping = {"Society": 0, "Science": 1, "Health": 2, "Education": 3, "Computer": 4,
            "Sports": 5, "Business": 6, "Entertainment": 7, "Relationship": 8, "Politics": 9}
    elif fine_tune_dataset.lower() == "mnli":
        label_mapping = {0: "entailment", 1: "neutral", 2: "contradiction"}
    elif fine_tune_dataset.lower() == "qqp":
        label_mapping = {0: "not duplicate", 1: "duplicate"}
    elif fine_tune_dataset.lower() == "rte":
        label_mapping = {0: "not entailment", 1: "entailment"}
    elif fine_tune_dataset.lower() == "sst2":
        label_mapping = {0: "negative", 1: "positive"}
    elif fine_tune_dataset.lower() == "wic":
        label_mapping = {0: "false", 1: "true"}
    elif fine_tune_dataset.lower() == "cb":
        label_mapping = {0: "contradiction", 1: "entailment", 2: "neutral"}
    elif fine_tune_dataset.lower() == "copa":
        label_mapping = {0: "choice1", 1: "choice2"}
    # elif fine_tune_dataset.lower() == "multirc":
    #   # You may need a custom mapping here.
        label_mapping = {}  
    elif fine_tune_dataset.lower() == "boolq":
        label_mapping = {0: "false", 1: "true"}
    elif fine_tune_dataset.lower() == "imdb":
        label_mapping = {0: "negative", 1: "positive"}
    else:
        raise ValueError(f"Unknown fine-tune dataset: {fine_tune_dataset}")
    
    # Use a prompt that indicates the dataset.
    dataset_prompt = fine_tune_dataset.lower()  # e.g., "dbpedia"

    model_name = "t5-large"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False  # disable cache for training

    # Load a base T5 model to auto-generate the target SVD config.
    # base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    base_model = T5WithSVD(config, svd_config={}, initialize_svd=False)
    base_model.load_state_dict(torch.load(starting_checkpoint, map_location=device), strict=False)
    base_model = base_model.to(device)
    target_svd_config = auto_generate_target_svd_config(base_model)
    # target_svd_config = auto_generate_target_svd_config(base_model, tokenizer)
    print("Auto-generated target SVD config:")
    for k, v in target_svd_config.items():
        print(f"  {k}: freeze top {v} singular vectors")

    # Initialize our custom SVD model with target_svd_config.
    model = T5WithSVD(config, svd_config=target_svd_config, initialize_svd=False)
    # Load pretrained weights into our SVD model.
    model.load_state_dict(torch.load(starting_checkpoint, map_location=device), strict=False)
    model.reinitialize_svd()
    model = model.to(device)

    # # Load the original AGNews state dictionary
    # orig_state = torch.load('t5_finetuned_agnews.pt', map_location=device)

    # # For each parameter in the original state, compare with the corresponding effective parameter in model.
    # for name, orig_param in orig_state.items():
    #     # If this parameter was decomposed (present in our svd_config), then use our reconstruction function.
    #     if name in model.svd_config:
    #         # Compute effective weight from SVD (including residual)
    #         effective_weight = model._reconstruct_weight(name)
    #         # Compute relative error
    #         error = torch.norm(orig_param.to(device) - effective_weight) / torch.norm(orig_param.to(device))
    #         print(f"{name} (decomposed): relative error = {error.item():.2e}")
    #     else:
    #         # Otherwise, compare directly.
    #         try:
    #             model_param = model.state_dict()[name]
    #             error = torch.norm(orig_param.to(device) - model_param) / torch.norm(orig_param.to(device))
    #             print(f"{name} (not decomposed): relative error = {error.item():.2e}")
    #         except KeyError:
    #             print(f"{name} is not present in the current model state_dict.")

    # torch.save(model.state_dict(), "t5_svd_amazon.pt")

    # Create datasets and dataloaders
    train_dataset = GenericClassificationDataset(train_json_path, tokenizer, label_mapping, dataset_prompt)
    test_dataset = GenericClassificationDataset(test_json_path, tokenizer, label_mapping, dataset_prompt)


    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=lambda batch: collate_fn_fn(batch, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn_fn(batch, tokenizer))

    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    num_epochs = 3  # adjust as needed

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            for key, val in batch.items():
                batch[key] = val.to(device)
            outputs = model(**batch, use_cache=False)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            model.project_gradients()  # ensure gradients remain in correct subspace
            optimizer.step()

            total_loss += loss.item()
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    # Save the fine-tuned model (with SVD modifications)
    torch.save(model.state_dict(), output_model_name)
    print(f"Model saved as '{output_model_name}'")
    return model, tokenizer, test_loader

###################################################
# 6. Inference
###################################################
def inference_svd_model(output_model_name=OUTPUT_MODEL_NAME):
    model_name = "t5-large"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False
    # Re-generate the same target SVD configuration
    base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    target_svd_config = auto_generate_target_svd_config(base_model)
    model = T5WithSVD(config, svd_config=target_svd_config)
    model.load_state_dict(torch.load(output_model_name), strict=False)
    model = model.to(device)
    model.eval()

    # Try a generation example – here we provide a sample review.
    input_text = "classify: This product exceeded my expectations and works perfectly!"
    input_enc = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = model.generate(**input_enc, max_length=16)
    print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

###################################################
# 6. Evaluation on Test Set
###################################################
def evaluate_model(model, tokenizer, test_loader):
    model.eval()
    total, correct = 0, 0
    for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
        # Move batch tensors to device
        for key, val in batch.items():
            batch[key] = val.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
        # Decode predictions and targets
        predictions = [tokenizer.decode(g, skip_special_tokens=True).strip().lower()
                       for g in generated_ids]
        targets = [tokenizer.decode(label, skip_special_tokens=True).strip().lower()
                   for label in batch["labels"]]
        for pred, target in zip(predictions, targets):
            total += 1
            if pred == target:
                correct += 1
    accuracy = correct / total if total > 0 else 0
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

###################################################
# 7. Main
###################################################
if __name__ == "__main__":
    model, tokenizer, test_loader = train_svd_model(fine_tune_dataset=FINE_TUNE_DATASET, starting_checkpoint=STARTING_CHECKPOINT, output_model_name=OUTPUT_MODEL_NAME)
    evaluate_model(model, tokenizer, test_loader)
    # inference_svd_model()

Auto-generated target SVD config:
  encoder.block.0.layer.0.SelfAttention.q.weight: freeze top 204 singular vectors
  encoder.block.0.layer.0.SelfAttention.k.weight: freeze top 204 singular vectors
  encoder.block.0.layer.0.SelfAttention.v.weight: freeze top 204 singular vectors
  encoder.block.0.layer.0.SelfAttention.o.weight: freeze top 204 singular vectors
  encoder.block.0.layer.1.DenseReluDense.wi.weight: freeze top 819 singular vectors
  encoder.block.0.layer.1.DenseReluDense.wo.weight: freeze top 819 singular vectors
  encoder.block.1.layer.0.SelfAttention.q.weight: freeze top 204 singular vectors
  encoder.block.1.layer.0.SelfAttention.k.weight: freeze top 204 singular vectors
  encoder.block.1.layer.0.SelfAttention.v.weight: freeze top 204 singular vectors
  encoder.block.1.layer.0.SelfAttention.o.weight: freeze top 204 singular vectors
  encoder.block.1.layer.1.DenseReluDense.wi.weight: freeze top 819 singular vectors
  encoder.block.1.layer.1.DenseReluDense.wo.weight: freeze

Epoch 1/3: 100%|██████████| 625/625 [11:10<00:00,  1.07s/batch, eta=1.07s, loss=0.6024]  


Epoch 1/3 - Average Loss: 0.7511


Epoch 2/3: 100%|██████████| 625/625 [11:12<00:00,  1.08s/batch, eta=1.08s, loss=0.4790]  


Epoch 2/3 - Average Loss: 0.5677


Epoch 3/3: 100%|██████████| 625/625 [11:14<00:00,  1.08s/batch, eta=1.08s, loss=0.5952]  


Epoch 3/3 - Average Loss: 0.5586
Model saved as 't5_svd_amazon.pt'


Evaluating: 100%|██████████| 950/950 [03:10<00:00,  4.98batch/s]

Test Accuracy: 22.66%





In [6]:
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, concatenate_datasets
from transformers import T5Tokenizer, T5Config

# Suppose T5WithSVD, construct_prompt, device, collate_fn_fn, etc. are already defined above...

####################################################################
# 1) New Class: EvaluationDataset
####################################################################
class EvaluationDataset(Dataset):
    """
    A simple evaluation dataset for text classification.
    This takes a single-split Hugging Face Dataset object (e.g., 'test' only)
    and constructs a prompt + label for each sample.
    """
    def __init__(self, hf_dataset, tokenizer, label_mapping, dataset_name):
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping
        
        # Convert each row into (prompt_text, label_text)
        self.examples = []
        for sample in hf_dataset:
            prompt = construct_prompt(sample, self.dataset_name)
            
            # The dataset's "label" might not always be present, so fallback to "topic"
            label_idx = sample.get("label", sample.get("topic"))
            
            # Convert numeric label into string label
            if label_idx is not None and label_idx in self.label_mapping:
                label_str = self.label_mapping[label_idx]
            else:
                label_str = "N/A"
            
            self.examples.append((prompt, label_str))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]  # (prompt, label)


####################################################################
# 2) Evaluate Model and Return Accuracy
####################################################################
def evaluate_model_return(model, tokenizer, test_loader):
    """Evaluates the model on a test loader and returns the accuracy."""
    model.eval()
    total, correct = 0, 0

    printed = 0
    max_print = 5
    
    for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
        for key, val in batch.items():
            batch[key] = val.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=16
            )
        predictions = [
            tokenizer.decode(g, skip_special_tokens=True).strip().lower()
            for g in generated_ids
        ]
        targets = [
            tokenizer.decode(label, skip_special_tokens=True).strip().lower()
            for label in batch["labels"]
        ]
        for pred, target in zip(predictions, targets):
            total += 1

            # Print up to 5 examples of predicted vs. actual
            if printed < max_print:
                print(f"\nExample {printed+1}")
                print(f"  Predicted: {pred}")
                print(f"  Actual:    {target}")
                printed += 1

            if pred == target:
                correct += 1
    return correct / total if total > 0 else 0


####################################################################
# 3) Evaluate on All Tasks
####################################################################
def evaluate_on_all_tasks(model_checkpoint, dataset_infos):
    """
    Loads the final model from `model_checkpoint` and evaluates it on each 
    task defined in `dataset_infos`.
    Prints the accuracy for each task and the overall average accuracy.
    """
    # Load the final model.
    model_name = "t5-large"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    config = T5Config.from_pretrained(model_name)
    config.use_cache = False
    
    # Use your custom SVD model class
    model = T5WithSVD(config, svd_config={})
    model.load_state_dict(torch.load(model_checkpoint, map_location=device), strict=False)
    model = model.to(device)
    model.eval()

    task_accuracies = {}

    # Loop over each task in dataset_infos
    for task_name, info in dataset_infos.items():
        print(f"\nEvaluating on {task_name} set:")
        # 1) Load the dataset from HF
        #    e.g. "hf_name"="ag_news", "split"="test"
        hf_dataset = load_dataset(
            info["hf_name"],
            split=info.get("split", "test"),
            **info.get("kwargs", {})
        )

        # Overwrite hf_dataset with up to 500 examples per class
        test_data = hf_dataset  # hf_dataset is already your "test" split
        new_splits = []
        label_mapping = info["label_mapping"]
        for lab in sorted(label_mapping.keys()):
            label_subset = test_data.filter(lambda x: x.get("label", x.get("topic")) == lab)
            label_subset = label_subset.shuffle(seed=42)
            label_subset = label_subset.select(range(min(len(label_subset), 500)))
            new_splits.append(label_subset)

        # Merge each class subset into one dataset
        test_data = concatenate_datasets(new_splits)
        hf_dataset = test_data
        
        # 2) Build the evaluation dataset
        eval_dataset = EvaluationDataset(
            hf_dataset,
            tokenizer,
            label_mapping,
            task_name
        )
        
        # 3) Build DataLoader
        test_loader = DataLoader(
            eval_dataset,
            batch_size=8,
            shuffle=False,
            collate_fn=lambda batch: collate_fn_fn(batch, tokenizer)
        )
        
        # 4) Evaluate & store accuracy
        acc = evaluate_model_return(model, tokenizer, test_loader)
        task_accuracies[task_name] = acc
        print(f"{task_name} accuracy: {acc*100:.2f}%")

    # Compute average accuracy across tasks
    avg_acc = np.mean(list(task_accuracies.values()))
    print("\nAverage accuracy across all tasks: {:.2f}%".format(avg_acc*100))
    return task_accuracies

# Define a dictionary mapping task names to dataset info.
DATASET_INFOS = {
    # "ag_news": {
    #     "hf_name": "ag_news",
    #     "split": "test",
    #     "label_mapping": {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    # },
    # "amazon": {
    #     "hf_name": "amazon_polarity",
    #     "split": "test",
    #     "label_mapping": {0: "negative", 1: "positive"}
    # },
    # "yelp": {
    #     "hf_name": "yelp_review_full",
    #     "split": "test",
    #     "label_mapping": {0: "1 star", 1: "2 star", 2: "3 star", 3: "4 star", 4: "5 star"}
    # },
    "dbpedia": {
        "hf_name": "dbpedia_14",
        "split": "test",
        "label_mapping": {0: "Company", 1: "Educational Institution", 2: "Artist",
                           3: "Athlete", 4: "Office Holder", 5: "Mean of Transportation",
                           6: "Building", 7: "Natural Place", 8: "Village",
                           9: "Animal", 10: "Plant", 11: "Album", 12: "Film", 13: "Written Work"}
    },
    # "yahoo": {
    #     "hf_name": "yahoo_answers_topics",
    #     "split": "test",
    #     "label_mapping": {0: "Society", 1: "Science", 2: "Health", 3: "Education",
    #                        4: "Computer", 5: "Sports", 6: "Business", 7: "Entertainment",
    #                        8: "Relationship", 9: "Politics"}
    # },
    # "mnli": {
    #     "hf_name": "glue",
    #     "kwargs": {"subset": "mnli"},
    #     "split": "test",
    #     "label_mapping": {0: "entailment", 1: "neutral", 2: "contradiction"}
    # },
    # "qqp": {
    #     "hf_name": "glue",
    #     "kwargs": {"subset": "qqp"},
    #     "split": "test",
    #     "label_mapping": {0: "not duplicate", 1: "duplicate"}
    # },
    # "rte": {
    #     "hf_name": "glue",
    #     "kwargs": {"subset": "rte"},
    #     "split": "test",
    #     "label_mapping": {0: "not entailment", 1: "entailment"}
    # },
    # "sst2": {
    #     "hf_name": "glue",
    #     "kwargs": {"subset": "sst2"},
    #     "split": "test",
    #     "label_mapping": {0: "negative", 1: "positive"}
    # },
    # "wic": {
    #     "hf_name": "super_glue",
    #     "kwargs": {"subset": "wic"},
    #     "split": "test",
    #     "label_mapping": {0: "false", 1: "true"}
    # },
    # "cb": {
    #     "hf_name": "super_glue",
    #     "kwargs": {"subset": "cb"},
    #     "split": "test",
    #     "label_mapping": {0: "contradiction", 1: "entailment", 2: "neutral"}
    # },
    # "copa": {
    #     "hf_name": "super_glue",
    #     "kwargs": {"subset": "copa"},
    #     "split": "test",
    #     "label_mapping": {0: "choice1", 1: "choice2"}
    # },
    # "boolq": {
    #     "hf_name": "super_glue",
    #     "kwargs": {"subset": "boolq"},
    #     "split": "test",
    #     "label_mapping": {0: "false", 1: "true"}
    # },
    # "imdb": {
    #     "hf_name": "imdb",
    #     "split": "test",
    #     "label_mapping": {0: "negative", 1: "positive"}
    # }
}

# Now call the evaluation function on the final model.
if __name__ == "__main__":
    final_model_path = "t5_finetuned_dbpedia.pt"  # Update this to your final model filename
    all_accuracies = evaluate_on_all_tasks(final_model_path, DATASET_INFOS)


Evaluating on dbpedia set:


Evaluating:   0%|          | 3/875 [00:00<01:07, 12.85batch/s]


Example 1
  Predicted: educational institution
  Actual:    company

Example 2
  Predicted: company
  Actual:    company

Example 3
  Predicted: company
  Actual:    company

Example 4
  Predicted: company
  Actual:    company

Example 5
  Predicted: company
  Actual:    company


Evaluating: 100%|██████████| 875/875 [01:03<00:00, 13.67batch/s]

dbpedia accuracy: 97.94%

Average accuracy across all tasks: 97.94%





In [2]:
import torch
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration

model_name = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
config = T5Config.from_pretrained(model_name)

# Optionally load pretrained T5 weights
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to("cuda")

# --- Code to list all parameter names and shapes ---
for name, param in model.named_parameters():
    print(name, param.size())

shared.weight torch.Size([32128, 1024])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([1024, 1024])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 16])
encoder.block.0.layer.0.layer_norm.weight torch.Size([1024])
encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([4096, 1024])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([1024, 4096])
encoder.block.0.layer.1.layer_norm.weight torch.Size([1024])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([1024, 1024])
encoder.block.1.layer.0.

In [1]:
from torch.nn import DataParallel
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_dataset
from tqdm import tqdm
import time
import random

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


###################################################
# 1. Define a PyTorch Dataset for DBpedia
###################################################
class DBpediaDataset(Dataset):
    """
    PyTorch dataset wrapper for the DBpedia dataset.
    Each example is converted to a text-to-text format.
    """
    def __init__(self, hf_dataset, split, tokenizer, label_mapping):
        """
        hf_dataset: the Hugging Face dataset loaded via load_dataset("dbpedia_14")
        split: "train" or "test"
        tokenizer: a T5Tokenizer instance
        label_mapping: a dict mapping integer labels to string labels, e.g. {0:"Company", ...}
        """
        self.dataset = hf_dataset[split]
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get the sample
        sample = self.dataset[idx]
        text = sample["content"]  # DBpedia has "content" as the text field
        label_id = sample["label"]

        # Create input prompt and target text
        input_text = "What is the topic of the following paragraph? Choose from [Company, Educational Institution, Artist, Athlete, Office Holder, Mean of Transportation, Building, Natural Place, Village, Animal, Plant, Album, Film, Written Work]. " + text
        target_text = self.label_mapping[label_id]
        return input_text, target_text


###################################################
# 2. Collate Function
###################################################
def collate_fn(batch, tokenizer, max_source_length=512, max_target_length=16):
    """
    Tokenize the batch of input and target texts.
    """
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), padding=True, truncation=True, max_length=max_source_length, return_tensors="pt")
    target_encodings = tokenizer(list(targets), padding=True, truncation=True, max_length=max_target_length, return_tensors="pt")

    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings


###################################################
# 3. Training and Evaluation Functions
###################################################
def train_finetune_t5():
    # Load the DBpedia dataset from Hugging Face
    hf_dataset = load_dataset("dbpedia_14")
    # For each of the 14 classes, shuffle and select up to 1000 examples
    train_set = hf_dataset["train"]
    selected_train = []
    for lab in range(14):
        label_subset = train_set.filter(lambda x: x["label"] == lab).shuffle(seed=42)
        selected_train.extend(label_subset.select(range(min(len(label_subset), 1000))))
    hf_dataset["train"] = selected_train

    # Define the label mapping for the 14 classes in DBpedia
    label_mapping = {
        0: "Company", 1: "Educational Institution", 2: "Artist", 3: "Athlete",
        4: "Office Holder", 5: "Mean of Transportation", 6: "Building",
        7: "Natural Place", 8: "Village", 9: "Animal", 10: "Plant",
        11: "Album", 12: "Film", 13: "Written Work"
    }

    # Load pretrained T5 tokenizer and model (T5-large)
    model_name = "t5-large"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)

    model = DataParallel(model)

    # Create PyTorch datasets for train and test splits
    train_dataset = DBpediaDataset(hf_dataset, "train", tokenizer, label_mapping)
    test_dataset = DBpediaDataset(hf_dataset, "test", tokenizer, label_mapping)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                              collate_fn=lambda batch: collate_fn(batch, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False,
                             collate_fn=lambda batch: collate_fn(batch, tokenizer))

    # Prepare optimizer (full fine-tuning; all model parameters are updated)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    num_epochs = 1

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=True)
        start_time = time.time()

        for batch in progress_bar:
            # Move batch to device
            for key, val in batch.items():
                batch[key] = val.to(device)

            outputs = model(**batch)
            loss = outputs.loss

            loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Estimate time remaining
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / (progress_bar.n + 1) * (len(train_loader) - progress_bar.n)
            progress_bar.set_postfix(loss=loss.item(), eta=f"{remaining_time:.2f}s")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    # Save the fine-tuned model
    torch.save(model.module.state_dict(), "t5_finetuned_dbpedia.pt")
    print("Model saved as 't5_finetuned_dbpedia.pt'.")

    return model, tokenizer, test_loader


def evaluate(model, tokenizer, test_loader):
    """
    Evaluate the fine-tuned model on the test set.
    """
    model.eval()
    correct = 0
    total = 0

    # For each of the 14 classes, shuffle and select up to 500 examples
    test_data = list(test_loader.dataset.dataset)  # the raw HF test set
    selected_test = []
    for lab in range(14):
        label_subset = [ex for ex in test_data if ex["label"] == lab]
        random.shuffle(label_subset)
        selected_test.extend(label_subset[:500])

    # Overwrite hf_dataset["test"] with the per-label subset
    test_loader.dataset.dataset = selected_test

    # Rebuild the DataLoader for evaluation
    test_loader = DataLoader(
        test_loader.dataset,  # same DBpediaDataset wrapper
        batch_size=64,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, tokenizer)
    )

    with torch.no_grad():
        for batch in test_loader:
            for key, val in batch.items():
                batch[key] = val.to(device)
            # Generate predictions
            generated_ids = model.module.generate(batch["input_ids"],
                                           attention_mask=batch["attention_mask"],
                                           max_length=16)
            predictions = [tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_ids]
            # Decode the ground truth labels
            targets = [tokenizer.decode(t, skip_special_tokens=True).strip() for t in batch["labels"]]

            for pred, target in zip(predictions, targets):
                if pred.lower() == target.lower():
                    correct += 1
                total += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Test Accuracy: {accuracy*100:.2f}%")
    return accuracy


###################################################
# 5. Main: Train, Check, and Evaluate
###################################################
if __name__ == "__main__":
    # Train and fine-tune T5 on DBpedia
    model1, tokenizer, test_loader = train_finetune_t5()

    # Evaluate the model
    evaluate(model1, tokenizer, test_loader)

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Epoch 1/1:   0%|          | 0/219 [00:00<?, ?batch/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Epoch 1/1: 100%|██████████| 219/219 [03:06<00:00,  1.17batch/s, eta=0.85s, loss=0.0626]   


Epoch 1/1 - Average Loss: 0.2419
Model saved as 't5_finetuned_dbpedia.pt'.
Test Accuracy: 98.66%
