Skip to content

Commit

Permalink
Merge pull request #13 from EricLBuehler/saving
Browse files Browse the repository at this point in the history
Support saving LoRA layers
  • Loading branch information
EricLBuehler committed Apr 3, 2024
2 parents 7fe56f4 + 675d3d2 commit 84d7c42
Show file tree
Hide file tree
Showing 27 changed files with 385 additions and 178 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,10 @@ transformers have been converted:

To use a LoRA transformer, simply replace the model from `candle-transformers` with its counterpart in `candle-lora-transformers`!

## Saving and loading
`candle_lora` supports retrieving weights for LoRA adapters via the `get_tensors` method, defined automatically in `#[auto_layer_convert]`. This function is meant to be used with `candle_core::safetensors::save()`. To load, simply load the `VarBuilder` and pass that to `get_lora_model`.

`candle_lora`'s weight naming is not compatible with `peft` yet.

## Resources
`candle-lora`'s LoRA conversion implementations are based on HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf), as well as Microsoft's [implementation](https://github.com/microsoft/LoRA).
20 changes: 10 additions & 10 deletions candle-lora-macro/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ This library makes using [`candle-lora`](https://github.com/EricLBuehler/candle-
`candle-lora-macro` exports 2 macros: `AutoLoraConvert` and `replace_layer_fields`.

The `AutoLoraConvert` derive macro automatically creates a method `get_lora_model`, when called which selects and swaps all supported layers for their LoRA counterparts. This method is the equivalent of `peft`'s `get_peft_model` method, and modifies the model in place. It expects all
layers of the supported types to be a `dyn` type, that is `Box<dyn ...LayerLike>`.
layers of the supported types to be a `dyn` type: `Arc<dyn ...LayerLike>`. **Therefore the type wrapping the layer must be `Arc`.**

In addition, `AutoLoraConvert` also defines a method `get_merged_lora_model` which does everything `get_lora_model` does, but also merges the weights of the LoRA layers to improve inference performance.

To further automate the process of using `candle-lora`, `candle-lora-macro` also provides an attribute macro called `replace_layer_fields`.
`replace_layer_fields` swaps out the concrete types for `dyn` types. If this macro is not added to the model structs, be sure to change the member types to `Box<dyn ...LayerLike>`.
`replace_layer_fields` swaps out the concrete types for `dyn` types. If this macro is not added to the model structs, be sure to change the member types to `Arc<dyn ...LayerLike>`.

`replace_layer_fields` is able to swap:
- `Linear` to `Box<dyn LinearLayerLike>`
- `Conv1d` to `Box<dyn Conv1dLayerLike>`
- `Conv2d` to `Box<dyn Conv2dLayerLike>`
- `Embedding` to `Box<dyn EmbeddigLayerLike>`
- `Option<Linear>` to `Option<Box<dyn LinearLayerLike>>`
- `Option<Conv1d>` to `Option<Box<dyn Conv1dLayerLike>>`
- `Option<Conv2d>` to `Option<Box<dyn Conv2dLayerLike>>`
- `Option<Embedding>` to `Option<Box<dyn EmbeddigLayerLike>>`
- `Linear` to `Arc<dyn LinearLayerLike>`
- `Conv1d` to `Arc<dyn Conv1dLayerLike>`
- `Conv2d` to `Arc<dyn Conv2dLayerLike>`
- `Embedding` to `Arc<dyn EmbeddigLayerLike>`
- `Option<Linear>` to `Option<Arc<dyn LinearLayerLike>>`
- `Option<Conv1d>` to `Option<Arc<dyn Conv1dLayerLike>>`
- `Option<Conv2d>` to `Option<Arc<dyn Conv2dLayerLike>>`
- `Option<Embedding>` to `Option<Arc<dyn EmbeddigLayerLike>>`
8 changes: 6 additions & 2 deletions candle-lora-macro/examples/linear.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_lora::{LinearLayerLike, LoraConfig, LoraLinearConfig};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
Expand All @@ -6,7 +8,7 @@ use candle_nn::{init, Linear, VarBuilder, VarMap};
#[replace_layer_fields]
#[derive(AutoLoraConvert, Debug)]
struct Model {
a: Box<dyn LinearLayerLike>,
a: Arc<dyn LinearLayerLike>,
b: i32,
}

Expand All @@ -32,7 +34,7 @@ fn main() {
.unwrap();

let mut model = Model {
a: Box::new(Linear::new(layer_weight.clone(), None)),
a: Arc::new(Linear::new(layer_weight.clone(), None)),
b: 1,
};

Expand All @@ -49,6 +51,8 @@ fn main() {
None,
);

dbg!(model.get_tensors());

let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();

//Test the model
Expand Down
241 changes: 146 additions & 95 deletions candle-lora-macro/src/lib.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion candle-lora-transformers/src/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ struct BertEmbedding {
}

impl Deref for BertEmbedding {
type Target = Arc<dyn EmbeddingLayerLike + Send + Sync>;
type Target = Arc<dyn EmbeddingLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/bigcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct CustomLinear {
}

impl Deref for CustomLinear {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand All @@ -30,7 +30,7 @@ struct CustomEmbedding {
}

impl Deref for CustomEmbedding {
type Target = Arc<dyn EmbeddingLayerLike + Send + Sync>;
type Target = Arc<dyn EmbeddingLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/dinov2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct DinoLinear {
}

impl Deref for DinoLinear {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down Expand Up @@ -290,7 +290,7 @@ struct DinoConv2d {
}

impl Deref for DinoConv2d {
type Target = Arc<dyn Conv2dLayerLike + Send + Sync>;
type Target = Arc<dyn Conv2dLayerLike>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/falcon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,15 @@ struct AttentionDense {
}

impl Deref for AttentionQKV {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.query_key_value
}
}

impl Deref for AttentionDense {
type Target = Arc<dyn LinearLayerLike + Send + Sync>;
type Target = Arc<dyn LinearLayerLike>;

fn deref(&self) -> &Self::Target {
&self.dense
Expand Down
7 changes: 7 additions & 0 deletions candle-lora-transformers/src/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_lora::{
EmbeddingLayerLike, LinearLayerLike, LoraConfig, LoraEmbeddingConfig, LoraLinearConfig,
Saveable,
};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::{Embedding, Module, VarBuilder};
Expand Down Expand Up @@ -103,6 +104,12 @@ impl Module for LlamaLinear {
}
}

impl Saveable for LlamaLinear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!()
}
}

impl LinearLayerLike for LlamaLinear {
fn bias(&self) -> Option<&Tensor> {
self.inner.bias()
Expand Down
2 changes: 1 addition & 1 deletion candle-lora-transformers/src/resnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! See "Deep Residual Learning for Image Recognition" He et al. 2015
//! <https://arxiv.org/abs/1512.03385>

use candle_core::{Module, Result, D};
use candle_core::{Module, Result, Tensor, D};
use candle_lora::{Conv2dLayerLike, LoraConfig, LoraConv2dConfig};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::{batch_norm, VarBuilder};
Expand Down
1 change: 0 additions & 1 deletion candle-lora/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ candle-core.workspace = true
candle-nn.workspace = true
either.workspace = true
thiserror.workspace = true
trc.workspace = true

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda"]
16 changes: 15 additions & 1 deletion candle-lora/src/frozenconv.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Module, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};

use crate::{Conv1dLayerLike, Conv2dLayerLike};
use crate::{Conv1dLayerLike, Conv2dLayerLike, Saveable};

/// Conv1d, but with a `new` implementation that ensures the weights are detached (frozen).
#[derive(Debug)]
Expand Down Expand Up @@ -42,6 +44,12 @@ impl Module for FrozenConv1d {
}
}

impl Saveable for FrozenConv1d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl Conv1dLayerLike for FrozenConv1d {
fn config(&self) -> &Conv1dConfig {
self.conv.config()
Expand Down Expand Up @@ -93,6 +101,12 @@ impl Module for FrozenConv2d {
}
}

impl Saveable for FrozenConv2d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl Conv2dLayerLike for FrozenConv2d {
fn config(&self) -> &Conv2dConfig {
self.conv.config()
Expand Down
10 changes: 9 additions & 1 deletion candle-lora/src/frozenembed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Result, Tensor};
use candle_nn::Embedding;

use crate::EmbeddingLayerLike;
use crate::{EmbeddingLayerLike, Saveable};

/// Embedding, but with a `new` implementation that ensures the embeddings are detached (frozen).
#[derive(Debug)]
Expand All @@ -27,6 +29,12 @@ impl crate::Module for FrozenEmbedding {
}
}

impl Saveable for FrozenEmbedding {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl EmbeddingLayerLike for FrozenEmbedding {
fn embeddings(&self) -> &Tensor {
self.embed.embeddings()
Expand Down
10 changes: 9 additions & 1 deletion candle-lora/src/frozenlinear.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Module, Result, Shape, Tensor};
use candle_nn::Linear;

use crate::LinearLayerLike;
use crate::{LinearLayerLike, Saveable};

/// Linear, but with a `new` implementation that ensures the weight and/or biases are detached (frozen).
#[derive(Debug)]
Expand All @@ -27,6 +29,12 @@ impl Module for FrozenLinear {
}
}

impl Saveable for FrozenLinear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl LinearLayerLike for FrozenLinear {
fn bias(&self) -> Option<&Tensor> {
self.linear.bias()
Expand Down
36 changes: 32 additions & 4 deletions candle-lora/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,23 @@ pub struct NewLayers<T: Eq + PartialEq + Hash> {
pub embed: HashMap<T, LoraEmbedding>,
}

pub trait Saveable {
fn get_tensors(&self, accum: &mut HashMap<String, Tensor>);
}

/// Any layer that is linear-like.
pub trait LinearLayerLike: Module + Debug {
pub trait LinearLayerLike: Module + Debug + Saveable + Send + Sync {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn shape(&self) -> &Shape;
}

impl Saveable for Linear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl LinearLayerLike for Linear {
fn weight(&self) -> &Tensor {
self.weight()
Expand All @@ -231,12 +241,18 @@ impl LinearLayerLike for Linear {
}

/// Any layer that is conv1d-like.
pub trait Conv1dLayerLike: Module + Debug {
pub trait Conv1dLayerLike: Module + Debug + Saveable + Send + Sync {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn config(&self) -> &Conv1dConfig;
}

impl Saveable for Conv1d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl Conv1dLayerLike for Conv1d {
fn config(&self) -> &Conv1dConfig {
self.config()
Expand All @@ -250,12 +266,18 @@ impl Conv1dLayerLike for Conv1d {
}

/// Any layer that is conv2d-like.
pub trait Conv2dLayerLike: Module + Debug {
pub trait Conv2dLayerLike: Module + Debug + Saveable + Send + Sync {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn config(&self) -> &Conv2dConfig;
}

impl Saveable for Conv2d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl Conv2dLayerLike for Conv2d {
fn config(&self) -> &Conv2dConfig {
self.config()
Expand All @@ -269,11 +291,17 @@ impl Conv2dLayerLike for Conv2d {
}

/// Any layer that is embedding-like.
pub trait EmbeddingLayerLike: Module + Debug {
pub trait EmbeddingLayerLike: Module + Debug + Saveable + Send + Sync {
fn embeddings(&self) -> &Tensor;
fn hidden_size(&self) -> usize;
}

impl Saveable for Embedding {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl EmbeddingLayerLike for Embedding {
fn embeddings(&self) -> &Tensor {
self.embeddings()
Expand Down
Loading

0 comments on commit 84d7c42

Please sign in to comment.