Support saving LoRA layers #13

Merged · 4 commits · Apr 3, 2024
5 changes: 5 additions & 0 deletions README.md
@@ -54,5 +54,10 @@ transformers have been converted:

To use a LoRA transformer, simply replace the model from `candle-transformers` with its counterpart in `candle-lora-transformers`!

## Saving and loading
`candle_lora` supports retrieving weights for LoRA adapters via the `get_tensors` method, defined automatically in `#[auto_layer_convert]`. The resulting tensors are meant to be saved with `candle_core::safetensors::save()`. To load, construct a `VarBuilder` from the saved weights and pass it to `get_lora_model` (both steps are sketched below).

`candle_lora`'s weight naming is not compatible with `peft` yet.
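
A minimal sketch of both steps, assuming the generated `get_tensors` yields a `HashMap<String, Tensor>` (exact signatures may differ; the file path is illustrative):

```rust
use std::collections::HashMap;

use candle_core::{safetensors, DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

/// Write LoRA adapter weights (e.g. as collected by the generated
/// `get_tensors` method) to a safetensors file.
fn save_adapters(tensors: &HashMap<String, Tensor>, path: &str) -> Result<()> {
    safetensors::save(tensors, path)
}

/// Reopen the saved adapters as a `VarBuilder`, ready to be passed back to
/// `get_lora_model` on a freshly constructed model.
fn load_adapters(path: &str, device: &Device) -> Result<VarBuilder<'static>> {
    // Memory-maps the file; `unsafe` because the mapping must remain valid
    // and unmodified while the builder is in use.
    unsafe { VarBuilder::from_mmaped_safetensors(&[path], DType::F32, device) }
}
```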

## Resources
`candle-lora`'s LoRA conversion implementations are based on HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf), as well as Microsoft's [implementation](https://github.com/microsoft/LoRA).
20 changes: 10 additions & 10 deletions candle-lora-macro/README.md
@@ -4,19 +4,19 @@ This library makes using [`candle-lora`](https://github.com/EricLBuehler/candle-
`candle-lora-macro` exports 2 macros: `AutoLoraConvert` and `replace_layer_fields`.

The `AutoLoraConvert` derive macro automatically creates a method `get_lora_model` which, when called, selects and swaps all supported layers for their LoRA counterparts. This method is the equivalent of `peft`'s `get_peft_model` method and modifies the model in place. It expects all
layers of the supported types to be a `dyn` type, that is `Box<dyn ...LayerLike>`.
layers of the supported types to be a `dyn` type: `Arc<dyn ...LayerLike>`. **Therefore the type wrapping the layer must be `Arc`.**

In addition, `AutoLoraConvert` defines a method `get_merged_lora_model`, which does everything `get_lora_model` does but also merges the weights of the LoRA layers to improve inference performance.

To further automate the process of using `candle-lora`, `candle-lora-macro` also provides an attribute macro called `replace_layer_fields`.
`replace_layer_fields` swaps out the concrete types for `dyn` types. If this macro is not added to the model structs, be sure to change the member types to `Box<dyn ...LayerLike>`.
`replace_layer_fields` swaps out the concrete types for `dyn` types. If this macro is not added to the model structs, be sure to change the member types to `Arc<dyn ...LayerLike>`.

`replace_layer_fields` is able to swap (see the sketch after this list):
- `Linear` to `Box<dyn LinearLayerLike>`
- `Conv1d` to `Box<dyn Conv1dLayerLike>`
- `Conv2d` to `Box<dyn Conv2dLayerLike>`
- `Embedding` to `Box<dyn EmbeddingLayerLike>`
- `Option<Linear>` to `Option<Box<dyn LinearLayerLike>>`
- `Option<Conv1d>` to `Option<Box<dyn Conv1dLayerLike>>`
- `Option<Conv2d>` to `Option<Box<dyn Conv2dLayerLike>>`
- `Option<Embedding>` to `Option<Box<dyn EmbeddingLayerLike>>`
- `Linear` to `Arc<dyn LinearLayerLike>`
- `Conv1d` to `Arc<dyn Conv1dLayerLike>`
- `Conv2d` to `Arc<dyn Conv2dLayerLike>`
- `Embedding` to `Arc<dyn EmbeddingLayerLike>`
- `Option<Linear>` to `Option<Arc<dyn LinearLayerLike>>`
- `Option<Conv1d>` to `Option<Arc<dyn Conv1dLayerLike>>`
- `Option<Conv2d>` to `Option<Arc<dyn Conv2dLayerLike>>`
- `Option<Embedding>` to `Option<Arc<dyn EmbeddingLayerLike>>`
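
A sketch of the intended usage, with and without the attribute macro (field names here are illustrative; the `linear.rs` example below is a complete program):

```rust
use std::sync::Arc;

use candle_lora::LinearLayerLike;
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::Linear;

// With `replace_layer_fields`, the concrete `Linear` field below is rewritten
// to `Arc<dyn LinearLayerLike>` during macro expansion, so `get_lora_model`
// can later swap it for its LoRA counterpart in place.
#[replace_layer_fields]
#[derive(AutoLoraConvert, Debug)]
struct Model {
    layer: Linear,
    scale: f64,
}

// Without the attribute macro, the trait object must be spelled out by hand.
#[derive(AutoLoraConvert, Debug)]
struct ModelManual {
    layer: Arc<dyn LinearLayerLike>,
    scale: f64,
}
```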
8 changes: 6 additions & 2 deletions candle-lora-macro/examples/linear.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_lora::{LinearLayerLike, LoraConfig, LoraLinearConfig};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
@@ -6,7 +8,7 @@ use candle_nn::{init, Linear, VarBuilder, VarMap};
#[replace_layer_fields]
#[derive(AutoLoraConvert, Debug)]
struct Model {
a: Box<dyn LinearLayerLike>,
a: Arc<dyn LinearLayerLike>,
b: i32,
}

@@ -32,7 +34,7 @@ fn main() {
.unwrap();

let mut model = Model {
a: Box::new(Linear::new(layer_weight.clone(), None)),
a: Arc::new(Linear::new(layer_weight.clone(), None)),
b: 1,
};

@@ -49,6 +51,8 @@ fn main() {
None,
);

dbg!(model.get_tensors());

let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();

//Test the model
241 changes: 146 additions & 95 deletions candle-lora-macro/src/lib.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion candle-lora-transformers/src/bert.rs
@@ -268,7 +268,7 @@ struct BertEmbedding {
}

impl Deref for BertEmbedding {
type Target = Arc<dyn EmbeddingLayerLike + Send + Sync>;
type Target = Arc<dyn EmbeddingLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/bigcode.rs
@@ -16,7 +16,7 @@
}

impl Deref for CustomLinear {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
@@ -30,7 +30,7 @@
}

impl Deref for CustomEmbedding {
type Target = Arc<dyn EmbeddingLayerLike + Send + Sync>;
type Target = Arc<dyn EmbeddingLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
@@ -497,7 +497,7 @@
let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;

let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;

let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
let input_embeds = self.wte.forward(input_ids)?;
let position_embeds = self.wpe.forward(&position_ids)?;
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/dinov2.rs
@@ -28,7 +28,7 @@ struct DinoLinear {
}

impl Deref for DinoLinear {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
@@ -290,7 +290,7 @@ struct DinoConv2d {
}

impl Deref for DinoConv2d {
type Target = Arc<dyn Conv2dLayerLike + Send + Sync>;
type Target = Arc<dyn Conv2dLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/falcon.rs
@@ -173,7 +173,7 @@
}
_ => {}
}
let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;

let inv_freq = self.inv_freq.to_dtype(dtype)?;
let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
@@ -219,15 +219,15 @@
}

impl Deref for AttentionQKV {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.query_key_value
}
}

impl Deref for AttentionDense {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.dense
7 changes: 7 additions & 0 deletions candle-lora-transformers/src/llama.rs
@@ -3,6 +3,7 @@
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_lora::{
EmbeddingLayerLike, LinearLayerLike, LoraConfig, LoraEmbeddingConfig, LoraLinearConfig,
Saveable,
};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::{Embedding, Module, VarBuilder};
@@ -103,6 +104,12 @@
}
}

impl Saveable for LlamaLinear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!()
}
}

impl LinearLayerLike for LlamaLinear {
fn bias(&self) -> Option<&Tensor> {
self.inner.bias()
@@ -135,7 +142,7 @@
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?

.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
2 changes: 1 addition & 1 deletion candle-lora-transformers/src/resnet.rs
@@ -3,7 +3,7 @@
//! See "Deep Residual Learning for Image Recognition" He et al. 2015
//! <https://arxiv.org/abs/1512.03385>

use candle_core::{Module, Result, D};
use candle_core::{Module, Result, Tensor, D};
use candle_lora::{Conv2dLayerLike, LoraConfig, LoraConv2dConfig};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::{batch_norm, VarBuilder};
1 change: 0 additions & 1 deletion candle-lora/Cargo.toml
@@ -16,7 +16,6 @@ candle-core.workspace = true
candle-nn.workspace = true
either.workspace = true
thiserror.workspace = true
trc.workspace = true

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda"]
16 changes: 15 additions & 1 deletion candle-lora/src/frozenconv.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Module, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};

use crate::{Conv1dLayerLike, Conv2dLayerLike};
use crate::{Conv1dLayerLike, Conv2dLayerLike, Saveable};

/// Conv1d, but with a `new` implementation that ensures the weights are detached (frozen).
#[derive(Debug)]
@@ -42,6 +44,12 @@ impl Module for FrozenConv1d {
}
}

impl Saveable for FrozenConv1d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl Conv1dLayerLike for FrozenConv1d {
fn config(&self) -> &Conv1dConfig {
self.conv.config()
@@ -93,6 +101,12 @@ impl Module for FrozenConv2d {
}
}

impl Saveable for FrozenConv2d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl Conv2dLayerLike for FrozenConv2d {
fn config(&self) -> &Conv2dConfig {
self.conv.config()
10 changes: 9 additions & 1 deletion candle-lora/src/frozenembed.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Result, Tensor};
use candle_nn::Embedding;

use crate::EmbeddingLayerLike;
use crate::{EmbeddingLayerLike, Saveable};

/// Embedding, but with a `new` implementation that ensures the embeddings are detached (frozen).
#[derive(Debug)]
@@ -27,6 +29,12 @@ impl crate::Module for FrozenEmbedding {
}
}

impl Saveable for FrozenEmbedding {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl EmbeddingLayerLike for FrozenEmbedding {
fn embeddings(&self) -> &Tensor {
self.embed.embeddings()
10 changes: 9 additions & 1 deletion candle-lora/src/frozenlinear.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Module, Result, Shape, Tensor};
use candle_nn::Linear;

use crate::LinearLayerLike;
use crate::{LinearLayerLike, Saveable};

/// Linear, but with a `new` implementation that ensures the weight and/or biases are detached (frozen).
#[derive(Debug)]
@@ -27,6 +29,12 @@ impl Module for FrozenLinear {
}
}

impl Saveable for FrozenLinear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl LinearLayerLike for FrozenLinear {
fn bias(&self) -> Option<&Tensor> {
self.linear.bias()
36 changes: 32 additions & 4 deletions candle-lora/src/lib.rs
@@ -211,13 +211,23 @@ pub struct NewLayers<T: Eq + PartialEq + Hash> {
pub embed: HashMap<T, LoraEmbedding>,
}

pub trait Saveable {
fn get_tensors(&self, accum: &mut HashMap<String, Tensor>);
}

/// Any layer that is linear-like.
pub trait LinearLayerLike: Module + Debug {
pub trait LinearLayerLike: Module + Debug + Saveable + Send + Sync {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn shape(&self) -> &Shape;
}

impl Saveable for Linear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl LinearLayerLike for Linear {
fn weight(&self) -> &Tensor {
self.weight()
@@ -231,12 +241,18 @@ impl LinearLayerLike for Linear {
}

/// Any layer that is conv1d-like.
pub trait Conv1dLayerLike: Module + Debug {
pub trait Conv1dLayerLike: Module + Debug + Saveable + Send + Sync {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn config(&self) -> &Conv1dConfig;
}

impl Saveable for Conv1d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl Conv1dLayerLike for Conv1d {
fn config(&self) -> &Conv1dConfig {
self.config()
@@ -250,12 +266,18 @@ impl Conv1dLayerLike for Conv1d {
}

/// Any layer that is conv2d-like.
pub trait Conv2dLayerLike: Module + Debug {
pub trait Conv2dLayerLike: Module + Debug + Saveable + Send + Sync {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn config(&self) -> &Conv2dConfig;
}

impl Saveable for Conv2d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl Conv2dLayerLike for Conv2d {
fn config(&self) -> &Conv2dConfig {
self.config()
@@ -269,11 +291,17 @@ impl Conv2dLayerLike for Conv2d {
}

/// Any layer that is embedding-like.
pub trait EmbeddingLayerLike: Module + Debug {
pub trait EmbeddingLayerLike: Module + Debug + Saveable + Send + Sync {
fn embeddings(&self) -> &Tensor;
fn hidden_size(&self) -> usize;
}

impl Saveable for Embedding {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl EmbeddingLayerLike for Embedding {
fn embeddings(&self) -> &Tensor {
self.embeddings()
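
For downstream code, a custom wrapper type can satisfy the new `Saveable` trait by forwarding to the LoRA layer it contains. A minimal sketch, assuming the `LoraLinear` produced by conversion implements `Saveable` (as the error messages above suggest); the wrapper type and field name are hypothetical:

```rust
use std::collections::HashMap;

use candle_core::Tensor;
use candle_lora::{LoraLinear, Saveable};

// Hypothetical wrapper around a converted LoRA linear layer.
#[derive(Debug)]
struct ProjectionBlock {
    proj: LoraLinear,
}

impl Saveable for ProjectionBlock {
    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
        // Delegate to the wrapped LoRA layer, which inserts its adapter
        // weights into the shared accumulator under its own keys.
        self.proj.get_tensors(accum);
    }
}
```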