Commit

Add linear, embed merged
EricLBuehler committed Sep 12, 2023
1 parent 8622c74 commit f46c270
Showing 8 changed files with 280 additions and 43 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -47,9 +47,9 @@ impl Module for Model {
 
 impl Model {
     fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-        for (name, conv) in new.linear {
+        for (name, linear) in new.linear {
             match name {
-                ModelLayers::Layer => self.layer = Box::new(conv),
+                ModelLayers::Layer => self.layer = Box::new(linear),
             }
         }
     }
21 changes: 21 additions & 0 deletions src/loraembed.rs
@@ -1,3 +1,5 @@
+use std::ops::Mul;
+
 use candle_core::{DType, Device, Module, Result, Tensor};
 use candle_nn::{init, Embedding, VarMap};
 
@@ -9,6 +11,7 @@ pub struct LoraEmbedding {
     a: Tensor,
     b: Tensor,
     scale: Option<f64>,
+    merged: bool,
 }
 
 /// Configuration for LoraEmbedding, with `num_embeddings` vectors of `embedding_dim` size`.
@@ -90,8 +93,26 @@ impl LoraEmbedding {
             } else {
                 None
             },
+            merged: false,
         })
     }
+
+    fn get_delta_weight(&self) -> Result<Tensor> {
+        let result = self.b.matmul(&self.a)?;
+        Ok(match self.scale {
+            Some(scale) => result.mul(scale)?,
+            None => result,
+        })
+    }
+
+    pub fn merge(&mut self) -> Result<()> {
+        self.old = FrozenEmbedding::new(
+            &(self.embeddings() + self.get_delta_weight()?.transpose(0, 1))?,
+            self.hidden_size(),
+        )?;
+        self.merged = true;
+        Ok(())
+    }
 }
 
 impl Module for LoraEmbedding {
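For reference, `LoraEmbedding::merge` above folds the low-rank update into the frozen embedding table: the delta is `scale * (B · A)`, transposed to match the `(vocab, hidden)` embedding layout before it is added. A minimal standalone sketch of that computation with plain `candle_core` tensors; the shapes and variable names here are illustrative assumptions, not taken from the crate:

```rust
use std::ops::Mul;

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Assumed shapes: vocab = 10, hidden = 4, rank = 2.
    let weight = Tensor::zeros((10, 4), DType::F32, &device)?; // frozen embedding weight
    let a = Tensor::ones((2, 10), DType::F32, &device)?; // LoRA A: (rank, vocab)
    let b = Tensor::ones((4, 2), DType::F32, &device)?; // LoRA B: (hidden, rank)
    let scale = 0.5;

    // delta = scale * (B · A) has shape (hidden, vocab); transpose it to (vocab, hidden)
    // before adding it to the embedding table, mirroring merge() above.
    let delta = b.matmul(&a)?.mul(scale)?;
    let merged = (weight + delta.transpose(0, 1)?)?;
    println!("merged embedding table: {merged}");
    Ok(())
}
```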
47 changes: 35 additions & 12 deletions src/loralinear.rs
@@ -12,6 +12,7 @@ pub struct LoraLinear {
     b: Tensor,
     scale: Option<f64>,
     dropout: Option<Dropout>,
+    merged: bool,
 }
 
 /// Configuration for LoraLinear
@@ -96,26 +97,48 @@ impl LoraLinear {
                 None
             },
             dropout: config.dropout.map(Dropout::new),
+            merged: false,
         })
     }
+
+    fn get_delta_weight(&self) -> Result<Tensor> {
+        let result = self.b.matmul(&self.a)?;
+        Ok(match self.scale {
+            Some(scale) => result.mul(scale)?,
+            None => result,
+        })
+    }
+
+    pub fn merge(&mut self) -> Result<()> {
+        self.old = FrozenLinear::new(
+            (self.old.weight() + self.get_delta_weight())?,
+            self.old.bias().cloned(),
+        )?;
+        self.merged = true;
+        Ok(())
+    }
 }
 
 impl Module for LoraLinear {
     fn forward(&self, input: &Tensor) -> Result<Tensor> {
-        //No fan_in_fan_out so no weight.transpose(0,1)
-        let mut result = self.old.forward(input)?;
-        if let Some(scale) = self.scale {
-            if self.dropout.is_some() {
-                result = (result + self.dropout.as_ref().unwrap().forward(input, true)?)?;
-            } else {
-                result = (result + input)?;
-            }
-            result = result.broadcast_add(
-                &result.matmul(&self.b.broadcast_matmul(&self.a.matmul(&result)?)?)?,
-            )?;
-            result = result.broadcast_add(&result.clone().mul(scale)?)?;
-        }
-        Ok(result)
+        if self.merged {
+            self.old.forward(input)
+        } else {
+            //No fan_in_fan_out so no weight.transpose(0,1)
+            let mut result = self.old.forward(input)?;
+            if let Some(scale) = self.scale {
+                if self.dropout.is_some() {
+                    result = (result + self.dropout.as_ref().unwrap().forward(input, true)?)?;
+                } else {
+                    result = (result + input)?;
+                }
+                result = result.broadcast_add(
+                    &result.matmul(&self.b.broadcast_matmul(&self.a.matmul(&result)?)?)?,
+                )?;
+                result = result.broadcast_add(&result.clone().mul(scale)?)?;
+            }
+            Ok(result)
+        }
     }
 }
 
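The same identity motivates `LoraLinear::merge` above: with candle's `(out, in)` weight layout, adding `scale * (B · A)` to the frozen weight once lets the merged layer answer with a single matmul. Below is a small self-contained check of that equivalence written against the standard LoRA formulation with plain `candle_core` calls; the shapes are assumed for illustration and the snippet does not use the crate's types:

```rust
use std::ops::Mul;

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Assumed shapes: in = 4, out = 3, rank = 2.
    let w = Tensor::randn(0f32, 1f32, (3, 4), &device)?; // frozen weight, (out, in)
    let a = Tensor::randn(0f32, 1f32, (2, 4), &device)?; // LoRA A, (rank, in)
    let b = Tensor::randn(0f32, 1f32, (3, 2), &device)?; // LoRA B, (out, rank)
    let scale = 0.5;
    let x = Tensor::randn(0f32, 1f32, (1, 4), &device)?; // one input row

    // Unmerged (textbook LoRA): base output plus the scaled low-rank path.
    let delta = b.matmul(&a)?.mul(scale)?; // (out, in)
    let unmerged = (x.matmul(&w.t()?)? + x.matmul(&delta.t()?)?)?;

    // Merged: fold the delta into the weight once, then do a single matmul.
    let merged_w = (w + delta)?;
    let merged = x.matmul(&merged_w.t()?)?;

    // The two outputs agree up to floating point error.
    println!("unmerged: {unmerged}\nmerged:   {merged}");
    Ok(())
}
```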
53 changes: 28 additions & 25 deletions src/main.rs
@@ -1,38 +1,41 @@
-use std::{collections::HashMap, hash::Hash};
+use candle_lora::LoraEmbeddingConfigBuilder;
 
-use candle_core::{DType, Device, Result, Tensor};
-use candle_lora::{
-    EmbeddingLayerLike, Lora, LoraEmbeddingConfigBuilder, NewLayers, SelectedLayers,
-};
-use candle_nn::{init, Embedding, Module, VarMap};
-
-#[derive(PartialEq, Eq, Hash)]
-enum ModelLayers {
-    Embed,
-}
-
-#[derive(Debug)]
-struct Model {
-    embed: Box<dyn EmbeddingLayerLike>,
-}
-
-impl Module for Model {
-    fn forward(&self, input: &Tensor) -> Result<Tensor> {
-        self.embed.forward(input)
-    }
-}
-
-impl Model {
-    fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-        for (name, conv) in new.embed {
-            match name {
-                ModelLayers::Embed => self.embed = Box::new(conv),
-            }
-        }
-    }
-}
-
-fn main() -> Result<()> {
+fn main() -> candle_core::Result<()> {
+    use std::{collections::HashMap, hash::Hash};
+
+    use candle_core::{DType, Device, Result, Tensor};
+    use candle_lora::{EmbeddingLayerLike, Lora, NewLayers, SelectedLayers};
+    use candle_nn::{init, Embedding, Module, VarMap};
+
+    #[derive(PartialEq, Eq, Hash)]
+    enum ModelLayers {
+        Embed,
+    }
+
+    #[derive(Debug)]
+    struct Model {
+        embed: Box<dyn EmbeddingLayerLike>,
+    }
+
+    impl Module for Model {
+        fn forward(&self, input: &Tensor) -> Result<Tensor> {
+            self.embed.forward(input)
+        }
+    }
+
+    impl Model {
+        fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
+            for (name, mut embed) in new.embed {
+                match name {
+                    ModelLayers::Embed => {
+                        embed.merge().unwrap();
+                        self.embed = Box::new(embed)
+                    }
+                }
+            }
+        }
+    }
+
     let device = Device::Cpu;
     let dtype = DType::F32;
 
4 changes: 2 additions & 2 deletions tests/embed.rs
@@ -26,9 +26,9 @@ fn embed() -> candle_core::Result<()> {
 
     impl Model {
         fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-            for (name, conv) in new.embed {
+            for (name, embed) in new.embed {
                 match name {
-                    ModelLayers::Embed => self.embed = Box::new(conv),
+                    ModelLayers::Embed => self.embed = Box::new(embed),
                 }
             }
         }
97 changes: 97 additions & 0 deletions tests/embed_merged.rs
@@ -0,0 +1,97 @@
+use candle_lora::LoraEmbeddingConfigBuilder;
+
+#[test]
+fn embed() -> candle_core::Result<()> {
+    use std::{collections::HashMap, hash::Hash};
+
+    use candle_core::{DType, Device, Result, Tensor};
+    use candle_lora::{EmbeddingLayerLike, Lora, NewLayers, SelectedLayers};
+    use candle_nn::{init, Embedding, Module, VarMap};
+
+    #[derive(PartialEq, Eq, Hash)]
+    enum ModelLayers {
+        Embed,
+    }
+
+    #[derive(Debug)]
+    struct Model {
+        embed: Box<dyn EmbeddingLayerLike>,
+    }
+
+    impl Module for Model {
+        fn forward(&self, input: &Tensor) -> Result<Tensor> {
+            self.embed.forward(input)
+        }
+    }
+
+    impl Model {
+        fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
+            for (name, mut embed) in new.embed {
+                match name {
+                    ModelLayers::Embed => {
+                        embed.merge().unwrap();
+                        self.embed = Box::new(embed)
+                    }
+                }
+            }
+        }
+    }
+    let device = Device::Cpu;
+    let dtype = DType::F32;
+
+    let in_size = 10;
+    let hidden_size = 3;
+
+    //Create the model
+    let map = VarMap::new();
+    let embed_weight = map.get(
+        (in_size, hidden_size),
+        "embed.weight",
+        init::ZERO,
+        dtype,
+        &device,
+    )?;
+
+    let mut model = Model {
+        embed: Box::new(Embedding::new(embed_weight, hidden_size)),
+    };
+
+    let dummy_image = Tensor::zeros((2, 4), DType::U32, &device)?;
+
+    //Test the model
+    let output = model.forward(&dummy_image).unwrap();
+    println!("Output: {output:?}");
+
+    //Select layers we want to convert
+    let linear_layers = HashMap::new();
+    let conv1d_layers = HashMap::new();
+    let conv2d_layers = HashMap::new();
+    let mut embed_layers = HashMap::new();
+    embed_layers.insert(ModelLayers::Embed, &*model.embed);
+    let selected = SelectedLayers {
+        linear: linear_layers,
+        linear_config: None,
+        conv1d: conv1d_layers,
+        conv1d_config: None,
+        conv2d: conv2d_layers,
+        conv2d_config: None,
+        embed: embed_layers,
+        embed_config: Some(
+            LoraEmbeddingConfigBuilder::default(&device, dtype, in_size, hidden_size).build(),
+        ),
+    };
+
+    //Create new LoRA layers from our layers
+    let new_layers = Lora::convert_model(selected);
+
+    //Custom methods to implement
+    model.insert_new(new_layers);
+
+    //Test the model
+    let lora_output = model.forward(&dummy_image).unwrap();
+    println!("LoRA Output: {lora_output:?}");
+
+    assert_eq!(lora_output.shape(), output.shape());
+
+    Ok(())
+}
4 changes: 2 additions & 2 deletions tests/linear.rs
@@ -26,9 +26,9 @@ fn single_linear() -> candle_core::Result<()> {
 
     impl Model {
         fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-            for (name, conv) in new.linear {
+            for (name, linear) in new.linear {
                 match name {
-                    ModelLayers::Layer => self.layer = Box::new(conv),
+                    ModelLayers::Layer => self.layer = Box::new(linear),
                 }
             }
         }