
TaskWeaver: Instance-Level Dynamic Model Adaptation via Hypernetworks

Course Project for CMU 10-623: Generative AI (Fall 2025)

TaskWeaver is a research implementation exploring hypernetwork-based approaches to instance-level dynamic language model adaptation. Unlike task-level approaches (e.g., Sakana AI's Text-to-LoRA) that generate one fixed set of LoRA weights per task, TaskWeaver generates unique LoRA weights for each individual input on the fly, conditioned on the prompt's semantic embedding. By training on diverse datasets, the hypernetwork learns to generalize, producing adapted weights that aim to maximize performance for any given input.


Overview

TaskWeaver addresses the challenge of efficient adaptation in language models by going beyond task-level optimization. Existing approaches include:

  • Static models: Use a single model for all inputs (no adaptation)
  • Task-specific LoRA adapters: Train separate adapters for each task (storage overhead, requires task identification)
  • Task-level hypernetworks (e.g., Text-to-LoRA): Generate one set of LoRA weights per task

TaskWeaver proposes a more granular approach: learn a hypernetwork that generates unique LoRA weights for each individual input based on its semantic content. This enables:

  • Instance-level adaptation: Every input receives custom LoRA weights optimized for that specific prompt
  • Powerful generalization: Trained on diverse datasets to handle any input distribution
  • Parameter efficiency: Single hypernetwork serves all tasks and inputs
  • No task identification needed: Adaptation is purely based on semantic content of the input prompt, not task labels

Key Features

  • Instance-Level Adaptation: Generates unique LoRA weights for every individual input, not just per-task
  • Semantic Conditioning: Uses each prompt's semantic embedding to create input-specific adaptations
  • Beyond Text-to-LoRA: More granular than task-level approaches, enabling finer-grained optimization
  • Multi-Dataset Training: Trains on diverse dataset mixtures to learn generalizable adaptation strategies
  • Flexible Architecture: Works with any transformer-based language model (our experiments use Pythia, Gemma 3, and Qwen 3) and can be adapted to anything that LoRA can be applied to
  • Comprehensive Evaluation: Systematic comparison across base models, static LoRA, and TaskWeaver
  • Modular Design: Clean abstractions for datasets, models, and evaluation
  • Analysis: The generated LoRA weights are visualized with t-SNE (see the sketch below)
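
The analysis notebooks in src/analysis/ are not reproduced here; the following is a minimal sketch of how such a t-SNE visualization could look, assuming the per-instance LoRA weights are flattened into one vector per input (function and variable names are illustrative):

```python
# Minimal sketch (not the repository's notebook): project flattened per-instance
# LoRA weight vectors to 2D with t-SNE and color points by source dataset.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_lora_tsne(weight_vectors: np.ndarray, dataset_labels: list[str]) -> None:
    """weight_vectors: (n_inputs, d) array of flattened instance-specific LoRA weights."""
    coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(weight_vectors)
    labels = np.array(dataset_labels)
    for name in sorted(set(dataset_labels)):
        mask = labels == name
        plt.scatter(coords[mask, 0], coords[mask, 1], s=5, label=name)
    plt.legend()
    plt.title("t-SNE of instance-specific LoRA weights")
    plt.show()
```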

Architecture

TaskWeaver consists of three main components:

1. DynamicLoraLinear Layer

Custom linear layer that supports batch-specific LoRA adaptations:

output = Wx + b + (α/r) * BAx

where B and A are instance-specific low-rank matrices generated by the hypernetwork.
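The layer in dynamic_lora.py is not reproduced here; below is a minimal sketch of a batch-aware LoRA linear implementing the formula above, assuming the hypernetwork hands the layer per-instance A and B matrices before each forward pass (names and shapes are illustrative):

```python
import torch
import torch.nn as nn

class DynamicLoraLinear(nn.Module):
    """Sketch of a batch-aware LoRA linear: y = Wx + b + (alpha/r) * B A x,
    where A and B are per-instance matrices supplied by the hypernetwork."""

    def __init__(self, base: nn.Linear, rank: int = 2, alpha: float = 8.0):
        super().__init__()
        self.base = base                    # frozen pretrained W and b
        self.scaling = alpha / rank         # the alpha/r factor from the formula
        self.lora_A = None                  # (batch, rank, in_features)
        self.lora_B = None                  # (batch, out_features, rank)

    def set_instance_weights(self, A: torch.Tensor, B: torch.Tensor) -> None:
        """Called with hypernetwork outputs before each batch's forward pass."""
        self.lora_A, self.lora_B = A, B

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, in_features)
        out = self.base(x)
        if self.lora_A is not None and self.lora_B is not None:
            delta = torch.einsum("bri,bsi->bsr", self.lora_A, x)      # A x
            delta = torch.einsum("bor,bsr->bso", self.lora_B, delta)  # B (A x)
            out = out + self.scaling * delta
        return out
```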

2. TaskWeaver Hypernetwork

Generates instance-specific LoRA weights for each input using:

  • Semantic Embedding Extraction: Extracts hidden state from the last prompt token for each input
  • Semantic Projection: Projects LM hidden states to hypernetwork embedding space
  • Contextual Embeddings: Layer, module, and matrix type embeddings for structural context
  • MLP: 2-layer network with GELU activations and LayerNorm
  • Output Heads: Separate linear heads for A and B matrices of each target module

This architecture ensures every input gets its own unique set of LoRA weights based on its semantic content.
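A minimal sketch of how these components can fit together for a single target module, assuming flattened A/B outputs that are reshaped downstream (class and argument names are illustrative, not the repository's exact API):

```python
import torch
import torch.nn as nn

class TaskWeaverHypernet(nn.Module):
    """Sketch: map a prompt's semantic embedding plus structural context to
    flattened A and B matrices for one target module."""

    def __init__(self, lm_hidden: int, hidden_dim: int, n_layers: int,
                 n_modules: int, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.project = nn.Linear(lm_hidden, hidden_dim)       # semantic projection
        self.layer_emb = nn.Embedding(n_layers, hidden_dim)   # structural context
        self.module_emb = nn.Embedding(n_modules, hidden_dim)
        self.matrix_emb = nn.Embedding(2, hidden_dim)         # 0 -> A, 1 -> B
        self.mlp = nn.Sequential(                             # 2-layer MLP, GELU + LayerNorm
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.LayerNorm(hidden_dim),
        )
        self.head_A = nn.Linear(hidden_dim, rank * in_features)
        self.head_B = nn.Linear(hidden_dim, out_features * rank)
        nn.init.zeros_(self.head_B.weight)                    # near-zero initial adaptation
        nn.init.zeros_(self.head_B.bias)

    def forward(self, sem_emb, layer_idx, module_idx):
        # sem_emb: (batch, lm_hidden) hidden state of each prompt's last token
        h = self.project(sem_emb)
        ctx = self.layer_emb(layer_idx) + self.module_emb(module_idx)
        A = self.head_A(self.mlp(h + ctx + self.matrix_emb.weight[0]))
        B = self.head_B(self.mlp(h + ctx + self.matrix_emb.weight[1]))
        return A, B   # reshaped downstream to (rank, in_features) and (out_features, rank)
```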

3. Training Pipeline

  • Freezes base language model parameters
  • Trains only the hypernetwork components
  • Uses prompt length tracking to extract semantic embeddings from the last prompt token of each instance
  • Trains on diverse, mixed datasets to learn generalizable adaptation strategies
  • Supports configurable dataset interleaving and stopping strategies
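A sketch of the setup described above, assuming HuggingFace transformers and datasets; it reuses the TaskWeaverHypernet sketch from the Architecture section, and the two-dataset mix is only an example:

```python
import torch
from transformers import AutoModelForCausalLM
from datasets import load_dataset, interleave_datasets

# Freeze every base-model parameter; only the hypernetwork is trained.
base_lm = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
for p in base_lm.parameters():
    p.requires_grad_(False)

# Interleave diverse datasets into one mixed training stream, mapped to a
# common "text" column, with a configurable stopping strategy.
gsm8k = load_dataset("gsm8k", "main", split="train").map(
    lambda ex: {"text": ex["question"] + "\n" + ex["answer"]},
    remove_columns=["question", "answer"],
)
boolq = load_dataset("google/boolq", split="train").map(
    lambda ex: {"text": f"{ex['passage']}\nQ: {ex['question']}\nA: {ex['answer']}"},
    remove_columns=["question", "answer", "passage"],
)
mixed = interleave_datasets([gsm8k, boolq], stopping_strategy="all_exhausted")

# Optimizer over hypernetwork parameters only (TaskWeaverHypernet is the sketch
# from the Architecture section; pythia-70m has hidden size 512 and 6 layers,
# with query_key_value producing 3 * 512 = 1536 output features).
hypernet = TaskWeaverHypernet(lm_hidden=512, hidden_dim=1024, n_layers=6,
                              n_modules=1, in_features=512, out_features=1536, rank=2)
optimizer = torch.optim.AdamW(hypernet.parameters(), lr=1e-6)
```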

Comparison to Text-to-LoRA

TaskWeaver builds upon and extends the hypernetwork-for-LoRA paradigm introduced by Sakana AI's Text-to-LoRA work (see refs/ directory). Key differences:

| Aspect | Text-to-LoRA (Sakana AI) | TaskWeaver (This Work) |
|---|---|---|
| Granularity | Task-level | Instance-level |
| LoRA Generation | One set of weights per task | Unique weights for each input |
| Conditioning | Task description/identifier | Semantic embedding of input prompt |
| Generalization | Relies on good task descriptions | Generalizes to any input based on content |
| Flexibility | Task switching | Continuous adaptation per instance |

Why Instance-Level?

Instance-level adaptation offers several advantages:

  1. Finer-grained optimization: Each input gets custom weights optimized for its specific semantic content
  2. No task boundaries: Works seamlessly across task distributions without explicit task identification
  3. Better generalization potential: Learns to adapt based on semantic patterns rather than task labels
  4. Handles task diversity: Can adapt to inputs that span multiple tasks or novel task combinations

By training on diverse datasets, TaskWeaver learns a universal adaptation function that maps any input's semantic content to optimal LoRA weights.

Installation

Prerequisites

  • Python 3.12
  • CUDA-capable GPU (recommended)
  • Conda or Miniconda

Setup

  1. Clone the repository:
git clone https://github.com/dkapur17/taskweaver.git
cd taskweaver
  2. Create the conda environment:
conda env create -f environment.yml
conda activate taskweaver
  3. Create output directories (if not already present):
mkdir -p _models/{hypernet,lora}
mkdir -p _results/{hypernet,lora,base}

Our models and results can be found here.

Project Structure

taskweaver/
├── src/
│   ├── hypernet/              # TaskWeaver hypernetwork implementation
│   │   ├── taskweaver.py      # Core TaskWeaver module
│   │   ├── dynamic_lora.py    # DynamicLoraLinear layer
│   │   └── collator.py        # Data collator with prompt length tracking
│   ├── lora/                  # LoRA baseline implementation
│   │   └── lora_finetuner.py  # Static LoRA training wrapper
│   ├── dsconf/                # Dataset configuration
│   │   └── dataset_configs.py # Dataset registry and processing
│   ├── eval/                  # Evaluation framework
│   │   ├── evaluator.py       # Main evaluator class
│   │   ├── task.py            # Task abstraction
│   │   └── eval_configs.py    # Evaluation configurations
│   ├── analysis/              # Analysis notebooks
│   ├── scripts/               # Training and evaluation scripts
│   ├── train_hypernet.py      # TaskWeaver training script
│   ├── train_lora.py          # LoRA training script
│   ├── evaluate.py            # Unified evaluation script
│   ├── _models/               # Trained models (gitignored)
│   └── _results/              # Evaluation results (gitignored)
└── environment.yml            # Conda environment specification

Usage

Training TaskWeaver

Train TaskWeaver on all datasets:

python src/train_hypernet.py \
  --model EleutherAI/pythia-70m \
  --datasets all \
  --ignore_datasets ChilleD/SVAMP \
  --hypernet.target_modules query_key_value \
  --hypernet.hidden_dim 1024 \
  --hypernet.lora_rank 2 \
  --hypernet.lora_alpha 8

Train on specific diverse datasets:

python src/train_hypernet.py \
  --model google/gemma-3-270m-it \
  --datasets gsm8k cais/mmlu \
  --hypernet.target_modules q_proj v_proj \
  --hypernet.hidden_dim 256

Exclude specific datasets from "all":

python src/train_hypernet.py \
  --model Qwen/Qwen3-0.6B \
  --datasets all \
  --ignore_datasets tau/commonsense_qa ChilleD/SVAMP

Training LoRA Baselines

Train task-specific LoRA adapters:

python src/train_lora.py \
  --model EleutherAI/pythia-70m \
  --datasets gsm8k \
  --lora.rank 2 \
  --lora.alpha 8 \
  --lora.target_modules query_key_value

Train on mixed datasets:

python src/train_lora.py \
  --model google/gemma-3-270m-it \
  --datasets mix \
  --lora.rank 2

Evaluation

Evaluate any model:

python src/evaluate.py \
  --model_path _models/hypernet/EleutherAI_pythia-70m/all_d1024_r2_a8 \
  --model_type hypernet \
  --datasets all \
  --split test

Batch evaluation using scripts:

# Evaluate all base models
bash src/scripts/eval_base_all.sh

# Evaluate all LoRA models
bash src/scripts/eval_lora_all.sh

# Evaluate all TaskWeaver models
bash src/scripts/eval_hypernet_all.sh

Models used in this work

| Model | Parameters | Target Modules | LoRA Rank | LoRA Alpha |
|---|---|---|---|---|
| EleutherAI/pythia-70m | 70M | query_key_value | 2 | 8 |
| google/gemma-3-270m-it | 270M | q_proj, v_proj | 2 | 8 |
| Qwen/Qwen3-0.6B | 600M | q_proj, v_proj | 2 | 8 |

Supported Datasets

Training & Evaluation Datasets

  • gsm8k: Grade school math word problems
  • cais/mmlu (ARC-Challenge, ARC-Easy): Multiple choice QA
  • google/boolq: Yes/No question answering
  • stanfordnlp/snli: Natural language inference
  • allenai/winogrande: Pronoun resolution
  • Rowan/hellaswag: Common sense reasoning
  • allenai/openbookqa: Knowledge-based QA

Held-Out Test Datasets

These datasets are excluded during training to evaluate generalization:

  • tau/commonsense_qa: Common sense question answering
  • ChilleD/SVAMP: Simple math word problems
  • ehovy/race.middle: Reading comprehension

More datasets can be added by creating the relevant configurations in src/dsconf/dataset_configs.py.
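The exact configuration schema lives in src/dsconf/dataset_configs.py and is not reproduced here; the entry below is a purely hypothetical illustration of the kind of information such a registry typically needs (field names are invented for this sketch):

```python
# Hypothetical sketch only -- the real schema is defined in
# src/dsconf/dataset_configs.py and its field names will differ.
from dataclasses import dataclass

@dataclass
class DatasetConfig:
    hf_path: str                 # HuggingFace dataset identifier
    subset: str | None           # optional config name, e.g. "main"
    prompt_template: str         # how a row is rendered into a prompt
    answer_field: str            # column holding the gold answer

DATASET_REGISTRY = {
    "allenai/sciq": DatasetConfig(
        hf_path="allenai/sciq",
        subset=None,
        prompt_template="Question: {question}\nAnswer:",
        answer_field="correct_answer",
    ),
}
```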

Experimental Results

The project includes a comprehensive evaluation comparing four approaches:

  1. Base Models: Unmodified pretrained models (no adaptation)
  2. LoRA Models: Static task-specific LoRA adapters (task-level adaptation)
  3. LoRA Mix: A single LoRA adapter trained on the same collection of datasets as TaskWeaver (static for all inputs)
  4. TaskWeaver: Dynamic instance-level hypernetwork-generated adaptations

Generalization Testing

TaskWeaver is evaluated on both:

  • Training datasets: Datasets seen during hypernetwork training
  • Held-out datasets: Unseen datasets to test generalization of the instance-level adaptation strategy

This tests whether the hypernetwork learns a universal adaptation function that generalizes beyond the training distribution.

| Model / Method | ARC Challenge | ARC Easy | BoolQ | GSM8K | HellaSwag | OpenBookQA | SNLI | Winogrande (M) | SVAMP | CommonSenseQA | RACE (Middle) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| EleutherAI/pythia-70m, Base | 9.13% | 7.07% | 14.2% | 1.29% | 4.94% | 8.2% | 0.519% | 6.08% | 2.33% | 3.44% | 8.64% |
| EleutherAI/pythia-70m, LoRA (Individual) | 22.8% | 23.7% | 53.5% | 1.14% | 25.4% | 22.4% | 34.9% | 49.7% | - | - | - |
| EleutherAI/pythia-70m, LoRA (Mixed) | 24.1% | 22.3% | 53.9% | 1.97% | 24.6% | 26.6% | 31.0% | 44.3% | 3.0% | 18.3% | 18.7% |
| EleutherAI/pythia-70m, TaskWeaver (Ours) | 24.9% | 23.9% | 54.5% | 1.67% | 24.8% | 28.6% | 29.5% | 48.4% | 1.0% | 19.7% | 25.1% |
| google/gemma-3-270m-it, Base | 18.9% | 23.0% | 43.5% | 4.55% | 24.7% | 24.6% | 34.3% | 50.6% | 19.3% | 18.4% | |
| google/gemma-3-270m-it, LoRA (Individual) | 25.6% | 25.0% | 56.4% | 3.64% | 24.2% | 26.4% | 42.9% | 53.8% | - | - | - |
| google/gemma-3-270m-it, LoRA (Mixed) | 25.3% | 21.3% | 52.0% | 3.41% | 23.5% | 25.2% | 33.3% | 50.9% | 4.67% | 20.6% | 22.8% |
| google/gemma-3-270m-it, TaskWeaver (Ours) | 21.1% | 23.7% | 45.7% | 5.84% | 24.5% | 27.6% | 34.7% | 47.8% | 16.3% | 21.5% | 23.0% |
| Qwen/Qwen3-0.6B, Base | 21.7% | 31.7% | 63.6% | 50.0% | 25.9% | 34.2% | 42.4% | 50.4% | 73.0% | 38.2% | 34.4% |
| Qwen/Qwen3-0.6B, LoRA (Individual) | 50.7% | 71.8% | 78.8% | 35.9% | 50.7% | 62.2% | 84.2% | 48.7% | - | - | - |
| Qwen/Qwen3-0.6B, LoRA (Mixed) | 54.4% | 68.8% | 68.6% | 34.5% | 33.0% | 53.4% | 76.9% | 48.5% | 52.7% | 50.9% | 64.8% |
| Qwen/Qwen3-0.6B, TaskWeaver (Ours) | 35.2% | 48.5% | 64.1% | 52.5% | 29.0% | 38.8% | 49.2% | 52.1% | 74.7% | 45.9% | 44.2% |

Results Storage

Results are saved in _results/ with the following structure:

_results/
├── base/{model_name}/
├── lora/{model_name}/{dataset}/
└── hypernet/{model_name}/{config}/

Each result file includes:

  • Per-task accuracy/metrics
  • Aggregated performance statistics
  • Model metadata and configuration
  • Timestamp and evaluation parameters

Implementation Details

Hypernetwork Architecture

The TaskWeaver hypernetwork processes each input independently to generate instance-specific LoRA weights:

  1. Semantic Embedding Extraction: Extract hidden state from last prompt token of each input
  2. Projection: Project to hypernetwork embedding space (hidden_dim)
  3. Contextual Conditioning: Add layer, module, and matrix embeddings for structural context
  4. MLP Processing: 2-layer MLP with GELU and LayerNorm transforms the semantic embedding
  5. Weight Generation: Linear heads produce unique A and B matrices for each layer and module

Key insight: The same input always produces the same LoRA weights, but different inputs get different weights based on their semantic content. This enables the hypernetwork to learn input-to-adaptation mappings across diverse tasks.
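A minimal sketch of step 1, assuming the collator records each example's prompt length and that sequences are right-padded (names are illustrative):

```python
import torch

def last_prompt_token_states(hidden_states: torch.Tensor,
                             prompt_lengths: torch.Tensor) -> torch.Tensor:
    """Sketch of semantic embedding extraction: pick the hidden state of each
    prompt's last token.

    hidden_states: (batch, seq_len, hidden) from the frozen base LM.
    prompt_lengths: (batch,) prompt token counts recorded by the collator.
    Assumes right padding, so position prompt_lengths - 1 is the last prompt token.
    """
    idx = (prompt_lengths - 1).clamp(min=0)
    batch = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch, idx]            # (batch, hidden)
```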

Training Configuration

Default hyperparameters:

  • Learning rate: 1e-6
  • Batch size: 2 per device
  • Gradient accumulation: 2 steps
  • Epochs: 3
  • Warmup ratio: 0.1
  • Optimizer: AdamW (via HuggingFace Trainer)
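Expressed as HuggingFace TrainingArguments, these defaults would look roughly as follows (the output_dir is illustrative):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="_models/hypernet/run",    # illustrative path
    learning_rate=1e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    warmup_ratio=0.1,
    optim="adamw_torch",                  # AdamW via the HF Trainer
)
```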

Key Design Decisions

  1. Base Model Freezing: Only hypernetwork parameters are trained, preserving base model knowledge
  2. Instance-Level Conditioning: Uses semantic embedding from each input's last prompt token
  3. Diverse Multi-Dataset Training: Trains on mixed datasets to learn generalizable adaptation strategies
  4. Unique Weights Per Input: Each input generates its own LoRA weights based on semantic content
  5. Single GPU Training: Dynamic LoRA injection doesn't support DataParallel
  6. Zero-Initialization: Output heads are initialized so that initial adaptations are near zero, mirroring standard LoRA initialization and allowing gradual learning
  7. Reuse of the Base Model: The frozen base model also serves as the semantic embedder for the input prompts within the hypernetwork

Evaluation Methodology

  • Pass@K Support: Multiple generations per sample with majority voting
  • Task-Specific Configs: Custom accuracy metrics per dataset type
  • Reproducibility: Seeded evaluation with configurable parameters
  • Comprehensive Logging: Detailed metadata and configuration tracking
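
A minimal sketch of the pass@K-with-majority-voting idea; generate is any callable that produces one sampled answer for a given prompt and seed (a hypothetical interface, not the evaluator's actual API):

```python
from collections import Counter
from typing import Callable

def majority_vote_correct(generate: Callable[[str, int], str],
                          prompt: str, gold: str, k: int = 5) -> bool:
    """Sample k seeded generations and check whether the majority answer
    matches the gold label."""
    answers = [generate(prompt, seed) for seed in range(k)]
    majority, _ = Counter(answers).most_common(1)[0]
    return majority == gold
```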

Citation

If you use this code in your research, please cite:

@misc{taskweaver2025,
  title={TaskWeaver: Dynamic Instance-Level Language Model Adaptation via Hypernetworks},
  author={Dhruv Kapur and Andrews George Varghese and Raj Maheshwari},
  year={2025},
  note={Course project for CMU 10-623: Generative AI}
}

License

This project is for educational purposes as part of CMU's 10-623 course.

References

See the refs/ directory for related work, including:

  • Text-to-LoRA (Sakana AI): Task-level hypernetwork approach that inspired this work

Acknowledgments

  • Course: CMU 10-623 Generative AI (Fall 2025)
  • Built using HuggingFace Transformers, PEFT, and TRL libraries
  • Inspired by Text-to-LoRA (Sakana AI) and other hypernetwork research
  • Extended to instance-level adaptation for finer-grained optimization
