Course Project for CMU 10-623: Generative AI (Fall 2025)
TaskWeaver is a research implementation exploring hypernetwork-based approaches for instance-level dynamic language model adaptation. Unlike task-level approaches such as Sakana AI's Text-to-LoRA, which generate one fixed set of LoRA weights per task, TaskWeaver generates unique LoRA weights on the fly for each individual input, conditioned on the semantic embedding of the input prompt. By training on diverse datasets, the hypernetwork aims to learn adaptations that generalize across inputs and improve performance on each one.
- Overview
- Key Features
- Architecture
- Comparison to Text-to-LoRA
- Installation
- Project Structure
- Usage
- Supported Models
- Supported Datasets
- Experimental Results
- Implementation Details
TaskWeaver addresses the challenge of efficient adaptation in language models by going beyond task-level optimization. Existing approaches include:
- Static models: Use a single model for all inputs (no adaptation)
- Task-specific LoRA adapters: Train separate adapters for each task (storage overhead, requires task identification)
- Task-level hypernetworks (e.g., Text-to-LoRA): Generate one set of LoRA weights per task
TaskWeaver proposes a more granular approach: learn a hypernetwork that generates unique LoRA weights for each individual input based on its semantic content. This enables:
- Instance-level adaptation: Every input receives custom LoRA weights optimized for that specific prompt
- Generalization: Trained on a diverse dataset mixture with the aim of adapting to inputs beyond any single task distribution, including held-out datasets
- Parameter efficiency: Single hypernetwork serves all tasks and inputs
- No task identification needed: Adaptation is purely based on semantic content of the input prompt, not task labels
- Instance-Level Adaptation: Generates unique LoRA weights for every individual input, not just per-task
- Semantic Conditioning: Uses each prompt's semantic embedding to create input-specific adaptations
- Beyond Text-to-LoRA: More granular than task-level approaches, enabling finer-grained optimization
- Multi-Dataset Training: Trains on diverse dataset mixtures to learn generalizable adaptation strategies
- Flexible Architecture: Designed for transformer-based language models (our experiments use Pythia, Gemma 3 and Qwen 3), and adaptable in principle to any architecture to which LoRA can be applied.
- Comprehensive Evaluation: Systematic comparison across base models, static LoRA, and TaskWeaver
- Modular Design: Clean abstractions for datasets, models, and evaluation
- Analysis: The generated LoRA weights are visualized and analyzed using t-SNE
TaskWeaver consists of three main components:
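DynamicLoraLinear: a custom linear layer that applies a different LoRA adaptation to each instance in a batch:

$$y_i = W x_i + b + \frac{\alpha}{r} B_i A_i x_i$$

where $W$ and $b$ are the frozen base weight and bias, $r$ is the LoRA rank, $\alpha$ is the LoRA scaling factor, and $B_i$ and $A_i$ are the low-rank matrices generated by the hypernetwork for instance $i$.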
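The per-instance update can be computed for a whole batch without a Python loop. Below is a minimal NumPy sketch of the formula above; the real layer is PyTorch code in src/hypernet/dynamic_lora.py, and the function name `dynamic_lora_linear` and the tensor layouts used here are assumptions for illustration.

```python
import numpy as np

def dynamic_lora_linear(x, W, b, A, B, alpha, r):
    """Per-instance LoRA forward pass: y_i = W x_i + b + (alpha / r) * B_i A_i x_i.

    x: (batch, seq, d_in)          input activations
    W: (d_out, d_in), b: (d_out,)  frozen base weight and bias
    A: (batch, r, d_in)            per-instance down-projection (from the hypernetwork)
    B: (batch, d_out, r)           per-instance up-projection (from the hypernetwork)
    """
    base = x @ W.T + b                        # shared frozen path, same for every instance
    down = np.einsum("bsi,bri->bsr", x, A)    # A_i x for every token of instance i
    up = np.einsum("bsr,bor->bso", down, B)   # B_i (A_i x)
    return base + (alpha / r) * up

rng = np.random.default_rng(0)
batch, seq, d_in, d_out, r = 2, 3, 4, 5, 2
x = rng.normal(size=(batch, seq, d_in))
W, b = rng.normal(size=(d_out, d_in)), rng.normal(size=d_out)
A = rng.normal(size=(batch, r, d_in))
B = rng.normal(size=(batch, d_out, r))
y = dynamic_lora_linear(x, W, b, A, B, alpha=8, r=r)
print(y.shape)  # (2, 3, 5)
```

Only $A_i$ and $B_i$ vary per instance, so the full-size base matmul is shared across the batch and the per-instance path costs only $O(r(d_{in} + d_{out}))$ per token.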
TaskWeaver hypernetwork: generates instance-specific LoRA weights for each input using:
- Semantic Embedding Extraction: Extracts hidden state from the last prompt token for each input
- Semantic Projection: Projects LM hidden states to hypernetwork embedding space
- Contextual Embeddings: Layer, module, and matrix type embeddings for structural context
- MLP: 2-layer network with GELU activations and LayerNorm
- Output Heads: Separate linear heads for A and B matrices of each target module
This architecture ensures every input gets its own unique set of LoRA weights based on its semantic content.
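The pipeline above can be sketched in NumPy as follows. This is an illustrative approximation, not the code in src/hypernet/taskweaver.py: the class `HyperNetSketch`, the helpers `gelu` and `layer_norm`, the exact MLP ordering (Linear, GELU, LayerNorm, twice), and the simplified random initialization are all assumptions. One shared MLP serves every (layer, module, matrix) slot because the contextual embeddings tell it which slot it is generating.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, eps=1e-5):
    # LayerNorm over the last axis, without learned scale/shift
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

class HyperNetSketch:
    """Hypothetical stand-in for the TaskWeaver hypernetwork (not the repo's API)."""

    def __init__(self, d_lm, hidden, n_layers, modules, rank, seed=0):
        # modules maps a target-module name to its (d_in, d_out)
        rng = np.random.default_rng(seed)

        def init(*shape):  # simplified small random init for every parameter
            return rng.normal(scale=0.02, size=shape)

        self.modules, self.rank, self.n_layers = modules, rank, n_layers
        self.proj = init(d_lm, hidden)                            # semantic projection
        self.layer_emb = init(n_layers, hidden)                   # layer embedding
        self.module_emb = {m: init(hidden) for m in modules}      # module embedding
        self.matrix_emb = {"A": init(hidden), "B": init(hidden)}  # matrix-type embedding
        self.w1, self.w2 = init(hidden, hidden), init(hidden, hidden)  # 2-layer MLP
        self.heads = {}
        for m, (d_in, d_out) in modules.items():
            self.heads[(m, "A")] = init(hidden, rank * d_in)   # flattened A: (rank, d_in)
            self.heads[(m, "B")] = init(hidden, d_out * rank)  # flattened B: (d_out, rank)

    def mlp(self, z):
        z = layer_norm(gelu(z @ self.w1))
        return layer_norm(gelu(z @ self.w2))

    def __call__(self, semantic_emb):
        """semantic_emb: (batch, d_lm), one vector per prompt -> dict of per-input A/B."""
        batch = semantic_emb.shape[0]
        s = semantic_emb @ self.proj  # (batch, hidden)
        weights = {}
        for layer in range(self.n_layers):
            for m, (d_in, d_out) in self.modules.items():
                for mat, shape in (("A", (self.rank, d_in)), ("B", (d_out, self.rank))):
                    z = s + self.layer_emb[layer] + self.module_emb[m] + self.matrix_emb[mat]
                    flat = self.mlp(z) @ self.heads[(m, mat)]
                    weights[(layer, m, mat)] = flat.reshape(batch, *shape)
        return weights

hn = HyperNetSketch(d_lm=16, hidden=32, n_layers=2,
                    modules={"q_proj": (16, 16), "v_proj": (16, 8)}, rank=2)
emb = np.random.default_rng(1).normal(size=(3, 16))  # semantic embeddings of 3 prompts
w = hn(emb)
print(w[(0, "v_proj", "B")].shape)  # (3, 8, 2)
```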
Training:
- Freezes base language model parameters
- Trains only the hypernetwork components
- Uses prompt length tracking to extract semantic embeddings from the last prompt token of each instance
- Trains on diverse, mixed datasets to learn generalizable adaptation strategies
- Supports configurable dataset interleaving and stopping strategies
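Extracting the semantic embedding from the last prompt token, rather than the last token of the full training sequence, keeps the conditioning signal free of the target answer and, plausibly, identical between training (prompt + answer) and inference (prompt only). A minimal NumPy sketch, assuming right padding so that every sequence starts at position 0; the function name `last_prompt_token_embedding` is hypothetical.

```python
import numpy as np

def last_prompt_token_embedding(hidden_states, prompt_lengths):
    """Gather the hidden state of each example's final prompt token.

    hidden_states:  (batch, seq, d) states from the frozen base LM
    prompt_lengths: (batch,) number of prompt tokens per example
    Assumes right padding, i.e. each sequence starts at position 0.
    """
    prompt_lengths = np.asarray(prompt_lengths)
    if np.any(prompt_lengths < 1) or np.any(prompt_lengths > hidden_states.shape[1]):
        raise ValueError("prompt length out of range")
    rows = np.arange(hidden_states.shape[0])
    return hidden_states[rows, prompt_lengths - 1]  # (batch, d)

h = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)
emb = last_prompt_token_embedding(h, [2, 4])  # picks h[0, 1] and h[1, 3]
print(emb)
```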
TaskWeaver builds upon and extends the hypernetwork-for-LoRA paradigm introduced by Sakana AI's Text-to-LoRA work (see the refs/ directory). Key differences:
| Aspect | Text-to-LoRA (Sakana AI) | TaskWeaver (This Work) |
|---|---|---|
| Granularity | Task-level | Instance-level |
| LoRA Generation | One set of weights per task | Unique weights for each input |
| Conditioning | Task description/identifier | Semantic embedding of input prompt |
| Generalization | Relies on good task descriptions | Generalizes to any input based on content |
| Flexibility | Task switching | Continuous adaptation per instance |
Instance-level adaptation offers several advantages:
- Finer-grained optimization: Each input gets custom weights optimized for its specific semantic content
- No task boundaries: Works seamlessly across task distributions without explicit task identification
- Better generalization potential: Learns to adapt based on semantic patterns rather than task labels
- Handles task diversity: Can adapt to inputs that span multiple tasks or novel task combinations
By training on diverse datasets, TaskWeaver aims to learn a general adaptation function that maps an input's semantic content to effective LoRA weights.
- Python 3.12
- CUDA-capable GPU (recommended)
- Conda or Miniconda
- Clone the repository:
```bash
git clone https://github.com/dkapur17/taskweaver.git
cd taskweaver
```
- Create the conda environment:
```bash
conda env create -f environment.yml
conda activate taskweaver
```
- Create output directories (if not already present):
```bash
mkdir -p _models/{hypernet,lora}
mkdir -p _results/{hypernet,lora,base}
```

Our models and results can be found here.
```
taskweaver/
├── src/
│ ├── hypernet/ # TaskWeaver hypernetwork implementation
│ │ ├── taskweaver.py # Core TaskWeaver module
│ │ ├── dynamic_lora.py # DynamicLoraLinear layer
│ │ └── collator.py # Data collator with prompt length tracking
│ ├── lora/ # LoRA baseline implementation
│ │ └── lora_finetuner.py # Static LoRA training wrapper
│ ├── dsconf/ # Dataset configuration
│ │ └── dataset_configs.py # Dataset registry and processing
│ ├── eval/ # Evaluation framework
│ │ ├── evaluator.py # Main evaluator class
│ │ ├── task.py # Task abstraction
│ │ └── eval_configs.py # Evaluation configurations
│ ├── analysis/ # Analysis notebooks
│ ├── scripts/ # Training and evaluation scripts
│ ├── train_hypernet.py # TaskWeaver training script
│ ├── train_lora.py # LoRA training script
│ ├── evaluate.py # Unified evaluation script
│ ├── _models/ # Trained models (gitignored)
│ └── _results/ # Evaluation results (gitignored)
└── environment.yml # Conda environment specification
```
Train TaskWeaver on all datasets:
```bash
python src/train_hypernet.py \
  --model EleutherAI/pythia-70m \
  --datasets all \
  --ignore_datasets ChilleD/SVAMP \
  --hypernet.target_modules query_key_value \
  --hypernet.hidden_dim 1024 \
  --hypernet.lora_rank 2 \
  --hypernet.lora_alpha 8
```
Train on specific diverse datasets:
```bash
python src/train_hypernet.py \
  --model google/gemma-3-270m-it \
  --datasets gsm8k cais/mmlu \
  --hypernet.target_modules q_proj v_proj \
  --hypernet.hidden_dim 256
```
Exclude specific datasets from "all":
```bash
python src/train_hypernet.py \
  --model Qwen/Qwen3-0.6B \
  --datasets all \
  --ignore_datasets tau/commonsense_qa ChilleD/SVAMP
```
Train task-specific LoRA adapters:
```bash
python src/train_lora.py \
  --model EleutherAI/pythia-70m \
  --datasets gsm8k \
  --lora.rank 2 \
  --lora.alpha 8 \
  --lora.target_modules query_key_value
```
Train on mixed datasets:
```bash
python src/train_lora.py \
  --model google/gemma-3-270m-it \
  --datasets mix \
  --lora.rank 2
```
Evaluate any model:
```bash
python src/evaluate.py \
  --model_path _models/hypernet/EleutherAI_pythia-70m/all_d1024_r2_a8 \
  --model_type hypernet \
  --datasets all \
  --split test
```
Batch evaluation using scripts:
```bash
# Evaluate all base models
bash src/scripts/eval_base_all.sh
# Evaluate all LoRA models
bash src/scripts/eval_lora_all.sh
# Evaluate all TaskWeaver models
bash src/scripts/eval_hypernet_all.sh
```

Supported models:

| Model | Parameters | Target Modules | LoRA Rank | LoRA Alpha |
|---|---|---|---|---|
| EleutherAI/pythia-70m | 70M | query_key_value | 2 | 8 |
| google/gemma-3-270m-it | 270M | q_proj, v_proj | 2 | 8 |
| Qwen/Qwen3-0.6B | 600M | q_proj, v_proj | 2 | 8 |
Training datasets:
- gsm8k: Grade school math word problems
- cais/mmlu (ARC-Challenge, ARC-Easy): Multiple choice QA
- google/boolq: Yes/No question answering
- stanfordnlp/snli: Natural language inference
- allenai/winogrande: Pronoun resolution
- Rowan/hellaswag: Common sense reasoning
- allenai/openbookqa: Knowledge-based QA
Held-out datasets (excluded during training to evaluate generalization):
- tau/commonsense_qa: Common sense question answering
- ChilleD/SVAMP: Simple math word problems
- ehovy/race.middle: Reading comprehension
More datasets can be added by creating the relevant configurations in src/dsconf/dataset_configs.py.
The project includes a comprehensive evaluation comparing four approaches:
- Base Models: Unmodified pretrained models (no adaptation)
- LoRA Models: Static task-specific LoRA adapters (task-level adaptation)
- LoRA Mix: A single LoRA adapter trained on the same collection of datasets as TaskWeaver (static for all inputs)
- TaskWeaver: Dynamic instance-level hypernetwork-generated adaptations
TaskWeaver is evaluated on both:
- Training datasets: Datasets seen during hypernetwork training
- Held-out datasets: Unseen datasets to test generalization of the instance-level adaptation strategy
This tests whether the hypernetwork learns a universal adaptation function that generalizes beyond the training distribution.
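Evaluation results (%). SVAMP, CommonsenseQA and RACE (Middle) are held-out datasets; "-" indicates that no individually trained LoRA adapter exists for them.

| Model | ARC Challenge | ARC Easy | BoolQ | GSM8K | HellaSwag | OpenBookQA | SNLI | Winogrande (M) | SVAMP | CommonsenseQA | RACE (Middle) |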
|---|---|---|---|---|---|---|---|---|---|---|---|
| EleutherAI/pythia-70m, Base | 9.13% | 7.07% | 14.2% | 1.29% | 4.94% | 8.2% | 0.519% | 6.08% | 2.33% | 3.44% | 8.64% |
| EleutherAI/pythia-70m, LoRA (Individual) | 22.8% | 23.7% | 53.5% | 1.14% | 25.4% | 22.4% | 34.9% | 49.7% | - | - | - |
| EleutherAI/pythia-70m, LoRA (Mixed) | 24.1% | 22.3% | 53.9% | 1.97% | 24.6% | 26.6% | 31.0% | 44.3% | 3.0% | 18.3% | 18.7% |
| EleutherAI/pythia-70m, TaskWeaver (Ours) | 24.9% | 23.9% | 54.5% | 1.67% | 24.8% | 28.6% | 29.5% | 48.4% | 1.0% | 19.7% | 25.1% |
| google/gemma-3-270m-it, Base | 18.9% | 23.0% | 43.5% | 4.55% | 24.7% | 24.6% | 34.3% | 50.6% | 19.3% | 18.4% | |
| google/gemma-3-270m-it, LoRA (Individual) | 25.6% | 25.0% | 56.4% | 3.64% | 24.2% | 26.4% | 42.9% | 53.8% | - | - | - |
| google/gemma-3-270m-it, LoRA (Mixed) | 25.3% | 21.3% | 52.0% | 3.41% | 23.5% | 25.2% | 33.3% | 50.9% | 4.67% | 20.6% | 22.8% |
| google/gemma-3-270m-it, TaskWeaver (Ours) | 21.1% | 23.7% | 45.7% | 5.84% | 24.5% | 27.6% | 34.7% | 47.8% | 16.3% | 21.5% | 23.0% |
| Qwen/Qwen3-0.6B, Base | 21.7% | 31.7% | 63.6% | 50.0% | 25.9% | 34.2% | 42.4% | 50.4% | 73.0% | 38.2% | 34.4% |
| Qwen/Qwen3-0.6B, LoRA (Individual) | 50.7% | 71.8% | 78.8% | 35.9% | 50.7% | 62.2% | 84.2% | 48.7% | - | - | - |
| Qwen/Qwen3-0.6B, LoRA (Mixed) | 54.4% | 68.8% | 68.6% | 34.5% | 33.0% | 53.4% | 76.9% | 48.5% | 52.7% | 50.9% | 64.8% |
| Qwen/Qwen3-0.6B, TaskWeaver (Ours) | 35.2% | 48.5% | 64.1% | 52.5% | 29.0% | 38.8% | 49.2% | 52.1% | 74.7% | 45.9% | 44.2% |
Results are saved in _results/ with the following structure:
```
_results/
├── base/{model_name}/
├── lora/{model_name}/{dataset}/
└── hypernet/{model_name}/{config}/
```
Each result file includes:
- Per-task accuracy/metrics
- Aggregated performance statistics
- Model metadata and configuration
- Timestamp and evaluation parameters
The TaskWeaver hypernetwork processes each input independently to generate instance-specific LoRA weights:
- Semantic Embedding Extraction: Extract hidden state from last prompt token of each input
- Projection: Project to hypernetwork embedding space (hidden_dim)
- Contextual Conditioning: Add layer, module, and matrix embeddings for structural context
- MLP Processing: 2-layer MLP with GELU and LayerNorm transforms the semantic embedding
- Weight Generation: Linear heads produce unique A and B matrices for each layer and module
Key insight: The same input always produces the same LoRA weights, but different inputs get different weights based on their semantic content. This enables the hypernetwork to learn input-to-adaptation mappings across diverse tasks.
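Because the heads emit a full set of A and B matrices per input, it helps to know how many parameters that is. The sketch below counts them; the helper name is hypothetical, and the Pythia-70M shapes (6 layers, hidden size 512, fused `query_key_value` projection from 512 to 1536) are taken from the public model config rather than from this repository. At rank 2 this comes to 24,576 generated values per input.

```python
def generated_lora_params(n_layers, module_shapes, rank):
    """Number of LoRA parameters the hypernetwork emits for ONE input.

    module_shapes maps each target module to (d_in, d_out);
    per layer and module, A is (rank, d_in) and B is (d_out, rank).
    """
    per_layer = sum(rank * d_in + d_out * rank for d_in, d_out in module_shapes.values())
    return n_layers * per_layer

# Assumed Pythia-70M shapes: 6 layers, fused QKV projection 512 -> 1536.
pythia_70m = generated_lora_params(6, {"query_key_value": (512, 1536)}, rank=2)
print(pythia_70m)  # 24576
```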
Default hyperparameters:
- Learning rate: 1e-6
- Batch size: 2 per device
- Gradient accumulation: 2 steps
- Epochs: 3
- Warmup ratio: 0.1
- Optimizer: AdamW (via HuggingFace Trainer)
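These defaults imply an effective batch of 2 × 2 = 4 examples per optimizer step. The sketch below does that arithmetic and approximates how the HuggingFace Trainer turns `warmup_ratio` into a warmup step count on a single GPU (rounding up the fraction of total update steps). The helper name and the 10,000-example dataset size are hypothetical, and newer Trainer versions may compute steps per epoch slightly differently.

```python
import math

def schedule_summary(num_examples, per_device_batch=2, grad_accum=2, epochs=3, warmup_ratio=0.1):
    """Approximate optimizer-step arithmetic for a single-GPU Trainer run."""
    effective_batch = per_device_batch * grad_accum         # examples per optimizer step
    batches_per_epoch = math.ceil(num_examples / per_device_batch)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)    # linear warmup length
    return effective_batch, total_steps, warmup_steps

print(schedule_summary(10_000))  # (4, 7500, 750)
```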
- Base Model Freezing: Only hypernetwork parameters are trained, preserving base model knowledge
- Instance-Level Conditioning: Uses semantic embedding from each input's last prompt token
- Diverse Multi-Dataset Training: Trains on mixed datasets to learn generalizable adaptation strategies
- Unique Weights Per Input: Each input generates its own LoRA weights based on semantic content
- Single-GPU Training: Dynamic LoRA injection does not support DataParallel, so training runs on a single GPU
- Zero-Initialization: Output heads are initialized so that the generated adaptations start at or near zero, mirroring standard LoRA initialization (B = 0, so the adapter initially leaves the base model unchanged) and allowing gradual learning
- Reuse Base Model: The frozen base model also serves as the hypernetwork's semantic embedder for the input prompts.
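If the B head's weights and bias start at exactly zero (an assumption; the text above only says the initial adaptations are small), every generated $B_i$ is zero, so $\Delta W_i = B_i A_i = 0$ and each adapted layer initially reproduces the frozen base layer regardless of the input. A minimal NumPy sketch with hypothetical head shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, d_in, d_out, r, n_inputs = 32, 16, 8, 2, 4

# Hypothetical output heads: small random A head, all-zero B head (LoRA-style).
head_A_w, head_A_b = rng.normal(scale=0.02, size=(hidden, r * d_in)), np.zeros(r * d_in)
head_B_w, head_B_b = np.zeros((hidden, d_out * r)), np.zeros(d_out * r)

z = rng.normal(size=(n_inputs, hidden))  # MLP features for 4 different inputs
A = (z @ head_A_w + head_A_b).reshape(n_inputs, r, d_in)
B = (z @ head_B_w + head_B_b).reshape(n_inputs, d_out, r)

delta_W = B @ A                   # (4, d_out, d_in): per-input weight update
print(np.abs(delta_W).max())      # 0.0, so every adapted layer equals the base layer at step 0
```

As with standard LoRA, the A head is still randomly initialized, so the gradient reaching the B head is nonzero from the first step even though the initial update is zero.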
- Pass@K Support: Multiple generations per sample with majority voting
- Task-Specific Configs: Custom accuracy metrics per dataset type
- Reproducibility: Seeded evaluation with configurable parameters
- Comprehensive Logging: Detailed metadata and configuration tracking
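With K generations per sample, majority voting keeps the most frequent answer. A minimal sketch; the function name, the normalization (strip and lowercase) and the first-seen tie-breaking are assumptions, not necessarily what src/eval does, since answer extraction there is handled by the task-specific configs.

```python
from collections import Counter

def majority_vote(answers):
    """Return the most common non-empty normalized answer among K generations.

    Ties go to the answer seen first, because Counter.most_common orders
    equal counts by first insertion.
    """
    valid = [a.strip().lower() for a in answers if a and a.strip()]
    if not valid:
        return None
    return Counter(valid).most_common(1)[0][0]

print(majority_vote(["B", "c", "b", "C", "b"]))  # b
```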
If you use this code in your research, please cite:
```bibtex
@misc{taskweaver2025,
  title={TaskWeaver: Dynamic Instance-Level Language Model Adaptation via Hypernetworks},
  author={Dhruv Kapur and Andrews George Varghese and Raj Maheshwari},
  year={2025},
  note={Course project for CMU 10-623: Generative AI}
}
```

This project is for educational purposes as part of CMU's 10-623 course.
See the refs/ directory for related work, including:
- Text-to-LoRA (Sakana AI): Task-level hypernetwork approach that inspired this work
- Course: CMU 10-623 Generative AI (Fall 2025)
- Built using HuggingFace Transformers, PEFT, and TRL libraries
- Inspired by Text-to-LoRA (Sakana AI) and other hypernetwork research
- Extended to instance-level adaptation for finer-grained optimization