Commit

Update readme, update impl of conv, linear
EricLBuehler committed Sep 11, 2023
1 parent 856ea2c commit e0cf62a
Showing 4 changed files with 50 additions and 51 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -4,14 +4,13 @@

LoRA (low rank adaptation) implemented in Rust for use with [`Candle`](https://github.com/huggingface/candle/tree/main).

It is based upon [this implementation](https://github.com/microsoft/LoRA/tree/main), which mirrors what HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library implements. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf).
It is based on HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf).

candle-lora is able to convert:
- `Linear` -> `LoraLinear`
- `Conv1d` -> `LoraConv1d`
- `Conv2d` -> `LoraConv2d`

**WIP statement: I plan to add conversion for `Embedding` very shortly!**
- `Embedding` -> `LoraEmbedding`

Current working example:
```rust
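For background on what these wrappers do: each `Lora*` layer keeps the original weight frozen and learns a low-rank update that is scaled and added on top. A minimal sketch of the idea for a plain linear weight, using illustrative names (`w`, `a`, `b`, `scale`, `merged_weight`) rather than the crate's API:

```rust
// Minimal sketch (illustrative, not the crate's API): the core LoRA update for a
// linear weight. `w` is the frozen (d_out, d_in) weight, `a` is (rank, d_in),
// `b` is (d_out, rank), and `scale` is the LoRA scaling factor.
use std::ops::Mul;

use candle_core::{Result, Tensor};

fn merged_weight(w: &Tensor, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
    let delta = b.matmul(a)?; // (d_out, d_in) low-rank update
    w + delta.mul(scale)?     // W' = W + scale * (B x A)
}
```

The conv layers changed below apply the same update, either by folding the delta into the kernel (`LoraConv1d`) or by running the two factors as extra convolutions (`LoraConv2d`).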
35 changes: 15 additions & 20 deletions src/loraconv1d.rs
@@ -1,7 +1,7 @@
use std::ops::Mul;

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{init, Conv1dConfig, VarMap};
use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarMap};

use crate::{frozenconv::FrozenConv1d, Conv1dLayerLike};

@@ -11,6 +11,7 @@ pub struct LoraConv1d {
a: Tensor,
b: Tensor,
scale: Option<f64>,
dropout: Option<Dropout>,
}

pub struct LoraConv1dConfig<'a> {
@@ -21,6 +22,7 @@ pub struct LoraConv1dConfig<'a> {
pub dtype: DType,
pub in_channels: usize,
pub out_channels: usize,
pub dropout: Option<f32>,
}

impl<'a> LoraConv1dConfig<'a> {
@@ -39,6 +41,7 @@ impl<'a> LoraConv1dConfig<'a> {
dtype,
in_channels,
out_channels,
dropout: Some(0.),
}
}
}
@@ -76,37 +79,29 @@ impl LoraConv1d {
} else {
None
},
dropout: config.dropout.map(Dropout::new),
})
}
}

impl Module for LoraConv1d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if let Some(scale) = self.scale {
let x = input;
let bias = self.bias();
let weight = (self.old.weight()
let bias = self.bias().cloned();

let mut weight = self.old.weight().clone();
if self.dropout.is_some() {
weight = self.dropout.as_ref().unwrap().forward(input, true)?;
}
let weight = (&weight
+ &self
.b
.matmul(&self.a)?
.broadcast_matmul(&self.a.broadcast_matmul(&weight)?)?
.reshape(self.old.weight().shape())?
.mul(scale)?)?;

let x = x.conv1d(
&weight,
self.config().padding,
self.config().stride,
self.config().dilation,
self.config().groups,
)?;
match &bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
let conv = Conv1d::new(weight, bias, *self.config());
conv.forward(input)
} else {
self.old.forward(input)
}
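
The rewritten forward above folds the low-rank delta into the kernel and then lets `candle_nn::Conv1d` handle padding, stride, dilation, groups, and the bias broadcast that the removed code did by hand. A standalone sketch of that merged-kernel path, with illustrative names and the dropout branch omitted:

```rust
// Standalone sketch (illustrative names, not the crate's API, dropout omitted):
// build W' = W + scale * reshape(B (A W)) and let candle_nn::Conv1d apply the
// convolution config and the bias.
use std::ops::Mul;

use candle_core::{Module, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig};

fn lora_conv1d_forward(
    weight: &Tensor,       // frozen kernel
    bias: Option<Tensor>,  // frozen bias, if any
    a: &Tensor,            // LoRA "down" factor
    b: &Tensor,            // LoRA "up" factor
    scale: f64,
    cfg: Conv1dConfig,     // config of the frozen conv
    input: &Tensor,
) -> Result<Tensor> {
    // Low-rank delta, broadcast-multiplied against the kernel and reshaped back.
    let delta = b
        .broadcast_matmul(&a.broadcast_matmul(weight)?)?
        .reshape(weight.shape())?;
    let merged = (weight + delta.mul(scale)?)?;
    Conv1d::new(merged, bias, cfg).forward(input)
}
```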
58 changes: 31 additions & 27 deletions src/loraconv2d.rs
@@ -1,7 +1,7 @@
use std::ops::Mul;

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{init, Conv2dConfig, VarMap};
use candle_nn::{init, Conv2d, Conv2dConfig, Dropout, VarMap};

use crate::{frozenconv::FrozenConv2d, Conv2dLayerLike};

@@ -11,6 +11,7 @@ pub struct LoraConv2d {
a: Tensor,
b: Tensor,
scale: Option<f64>,
dropout: Option<Dropout>,
}

pub struct LoraConv2dConfig<'a> {
@@ -19,8 +20,9 @@ pub struct LoraConv2dConfig<'a> {
pub kernel_size: usize,
pub device: &'a Device,
pub dtype: DType,
in_channels: usize,
out_channels: usize,
pub in_channels: usize,
pub out_channels: usize,
pub dropout: Option<f32>,
}

impl<'a> LoraConv2dConfig<'a> {
@@ -39,6 +41,7 @@ impl<'a> LoraConv2dConfig<'a> {
dtype,
in_channels,
out_channels,
dropout: Some(0.),
}
}
}
@@ -48,8 +51,10 @@ impl LoraConv2d {
let map = VarMap::new();
let a = map.get(
(
config.rank * config.kernel_size,
config.in_channels * config.kernel_size,
config.rank, // * config.kernel_size,
config.in_channels / old.config().groups,
old.weight().dim(2).unwrap(),
old.weight().dim(3).unwrap(), // * config.kernel_size,
),
"a.weight",
init::DEFAULT_KAIMING_NORMAL,
@@ -58,8 +63,10 @@
)?;
let b = map.get(
(
config.out_channels / old.config().groups * config.kernel_size,
config.rank * config.kernel_size,
config.out_channels, // / old.config().groups * config.kernel_size,
config.rank / old.config().groups,
1,
1, // * config.kernel_size,
),
"b.weight",
init::ZERO,
@@ -76,36 +83,33 @@
} else {
None
},
dropout: config.dropout.map(Dropout::new),
})
}
}

impl Module for LoraConv2d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if let Some(scale) = self.scale {
let x = input;
let bias = self.bias();
let weight = self.old.forward(input)?;
let mut a_input = input.clone();
if self.dropout.is_some() {
a_input = self.dropout.as_ref().unwrap().forward(input, true)?;
}

let weight = (self.old.weight()
+ self.b.matmul(&self.a)?.reshape(self.old.weight().shape())?)?
.mul(scale)?;
let a_conv = Conv2d::new(self.a.clone(), None, *self.config());
let b_conv = Conv2d::new(
self.b.clone(),
None,
Conv2dConfig {
stride: 1,
..*self.config()
},
);

let x = x.conv2d(
&weight,
self.config().padding,
self.config().stride,
self.config().dilation,
self.config().groups,
)?;
let tmp = b_conv.forward(&a_conv.forward(&a_input)?)?;

match &bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
&weight + tmp.mul(scale)?
} else {
self.old.forward(input)
}
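
Instead of materialising a merged kernel, the new `Conv2d` forward runs the frozen convolution for the base output and pushes the input (dropped out during training in the crate, omitted here) through two extra convolutions: the A factor as a rank-channel convolution with the original kernel size, and the B factor as a 1x1 convolution whose scaled output is added back. A standalone sketch under those shape assumptions, with illustrative names:

```rust
// Standalone sketch (illustrative names, not the crate's API) of the two-pass path:
// base output from the frozen conv, plus a scaled low-rank correction from two
// extra convolutions. `a` has shape (rank, in_channels / groups, k_h, k_w) and
// `b` has shape (out_channels, rank / groups, 1, 1), matching the factors above.
use std::ops::Mul;

use candle_core::{Module, Result, Tensor};
use candle_nn::{Conv2d, Conv2dConfig};

fn lora_conv2d_forward(
    frozen: &Conv2d,    // the original convolution, weights untouched
    a: Tensor,
    b: Tensor,
    scale: f64,
    cfg: Conv2dConfig,  // the frozen conv's config
    input: &Tensor,
) -> Result<Tensor> {
    // A reuses the frozen conv's stride/padding/dilation/groups; B is a stride-1 1x1 conv.
    let a_conv = Conv2d::new(a, None, cfg);
    let b_conv = Conv2d::new(b, None, Conv2dConfig { stride: 1, ..cfg });

    let correction = b_conv.forward(&a_conv.forward(input)?)?;
    let base = frozen.forward(input)?;
    &base + correction.mul(scale)?
}
```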
3 changes: 2 additions & 1 deletion tests/conv2d.rs
@@ -1,4 +1,5 @@
fn main() -> candle_core::Result<()> {
#[test]
fn conv2d() -> candle_core::Result<()> {
use std::{collections::HashMap, hash::Hash};

use candle_core::{DType, Device, Result, Tensor};
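
With the `#[test]` attribute, the conv2d example runs under `cargo test` instead of needing to be built as a standalone binary.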
