Fix readme example
EricLBuehler committed Sep 17, 2023
1 parent 3ce7814 commit 3801d79
Showing 1 changed file with 42 additions and 26 deletions.

README.md
@@ -24,21 +24,37 @@ Together, these macros mean that `candle-lora` can be added to any `candle` model
 See an example with Llama [here](examples/llama). I will add a training example soon!
 
 ```rust
-use candle_core::{DType, Device, Module, Result, Tensor};
-use candle_lora::{LinearLayerLike, LoraConfig, LoraLinearConfig};
-use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
-use candle_nn::{init, Linear, VarMap};
+use candle_core::{DType, Device, Result, Tensor};
+use candle_lora::{
+    LinearLayerLike, Lora, LoraConfig, LoraLinearConfig, NewLayers, SelectedLayersBuilder,
+};
+use candle_nn::{init, Linear, Module, VarBuilder, VarMap};
 use std::{collections::HashMap, hash::Hash};
 
-#[replace_layer_fields]
-#[derive(AutoLoraConvert, Debug)]
+#[derive(PartialEq, Eq, Hash)]
+enum ModelLayers {
+    Layer,
+}
+
+#[derive(Debug)]
 struct Model {
-    a: Linear,
-    b: i32,
+    layer: Box<dyn LinearLayerLike>,
 }
 
 impl Module for Model {
     fn forward(&self, input: &Tensor) -> Result<Tensor> {
-        self.a.forward(input)
+        self.layer.forward(input)
     }
 }
+
+impl Model {
+    fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
+        for (name, linear) in new.linear {
+            match name {
+                ModelLayers::Layer => self.layer = Box::new(linear),
+            }
+        }
+    }
+}

@@ -58,27 +74,27 @@ fn main() {
         .unwrap();
 
     let mut model = Model {
-        a: Box::new(Linear::new(layer_weight.clone(), None)),
-        b: 1,
+        layer: Box::new(Linear::new(layer_weight.clone(), None)),
     };
 
-    let loraconfig = LoraConfig::new(1, 1., None, &device, dtype);
-    model.get_lora_model(
-        loraconfig,
-        Some(LoraLinearConfig::new(10, 10)),
-        None,
-        None,
-        None,
-    );
-
-    let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();
-
-    //Test the model
-    let digit = model.forward(&dummy_image).unwrap();
-    println!("Output: {digit:?}");
-
-    println!("{:?}", model.a);
-    println!("{:?}", model.b);
+    let mut linear_layers = HashMap::new();
+    linear_layers.insert(ModelLayers::Layer, &*model.layer);
+    let selected = SelectedLayersBuilder::new()
+        .add_linear_layers(linear_layers, LoraLinearConfig::new(10, 10))
+        .build();
+
+    let varmap = VarMap::new();
+    let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
+
+    let loraconfig = LoraConfig::new(1, 1., None);
+
+    let new_layers = Lora::convert_model(selected, loraconfig, &vb);
+
+    model.insert_new(new_layers);
+
+    let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();
+
+    let lora_output = model.forward(&dummy_image).unwrap();
+    println!("Output: {lora_output:?}");
 }
 ```
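Below is a minimal sketch (not part of this commit) of how the same builder-based flow could extend to a model with more than one linear layer. It reuses only the API surface shown in the diff above (`SelectedLayersBuilder`, `Lora::convert_model`, `NewLayers`, `LoraConfig`, `LoraLinearConfig`), assuming the `candle-lora` API exactly as this commit presents it; the two-layer `Mlp` struct, its field names, and the weight names are hypothetical.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_lora::{
    LinearLayerLike, Lora, LoraConfig, LoraLinearConfig, NewLayers, SelectedLayersBuilder,
};
use candle_nn::{init, Linear, Module, VarBuilder, VarMap};
use std::collections::HashMap;

// Hypothetical keys: one variant per layer that should receive LoRA weights.
#[derive(PartialEq, Eq, Hash)]
enum MlpLayers {
    Fc1,
    Fc2,
}

// A hypothetical two-layer MLP; the layers are trait objects so they can be
// swapped for their LoRA counterparts after conversion.
#[derive(Debug)]
struct Mlp {
    fc1: Box<dyn LinearLayerLike>,
    fc2: Box<dyn LinearLayerLike>,
}

impl Module for Mlp {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        self.fc2.forward(&self.fc1.forward(input)?)
    }
}

impl Mlp {
    // Route each converted layer back to the field named by its enum key.
    fn insert_new(&mut self, new: NewLayers<MlpLayers>) {
        for (name, linear) in new.linear {
            match name {
                MlpLayers::Fc1 => self.fc1 = Box::new(linear),
                MlpLayers::Fc2 => self.fc2 = Box::new(linear),
            }
        }
    }
}

fn main() {
    let device = Device::Cpu;
    let dtype = DType::F32;

    // Plain base weights for a 10 -> 10 -> 10 stack.
    let map = VarMap::new();
    let w1 = map
        .get((10, 10), "fc1.weight", init::DEFAULT_KAIMING_NORMAL, dtype, &device)
        .unwrap();
    let w2 = map
        .get((10, 10), "fc2.weight", init::DEFAULT_KAIMING_NORMAL, dtype, &device)
        .unwrap();

    let mut model = Mlp {
        fc1: Box::new(Linear::new(w1, None)),
        fc2: Box::new(Linear::new(w2, None)),
    };

    // Select both layers; a single LoraLinearConfig works here because the
    // layers share the same input/output dimensions.
    let mut linear_layers = HashMap::new();
    linear_layers.insert(MlpLayers::Fc1, &*model.fc1);
    linear_layers.insert(MlpLayers::Fc2, &*model.fc2);
    let selected = SelectedLayersBuilder::new()
        .add_linear_layers(linear_layers, LoraLinearConfig::new(10, 10))
        .build();

    // Fresh VarMap for the converted layers' variables, mirroring the README example.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
    let loraconfig = LoraConfig::new(1, 1., None);

    let new_layers = Lora::convert_model(selected, loraconfig, &vb);
    model.insert_new(new_layers);

    let input = Tensor::zeros((10, 10), DType::F32, &device).unwrap();
    println!("Output: {:?}", model.forward(&input).unwrap());
}
```

One design point carried over from the commit: layer selection is keyed by a user-defined `Eq + Hash` enum rather than by strings, so the `match` in `insert_new` is exhaustive and the compiler flags any layer variant left unhandled.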
