Commit

Add test
EricLBuehler committed Sep 11, 2023
1 parent ad03a78 commit 856ea2c
Showing 1 changed file with 98 additions and 0 deletions.
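The new test builds a tiny model around a single `Embedding` layer, runs a forward pass, converts the layer to a LoRA layer with `Lora::convert_model`, swaps it into the model, and checks that the wrapped layer produces output of the same shape as the original. A conceptual sketch of the LoRA embedding forward pass follows the diff.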
tests/embed.rs: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
#[test]
fn embed() -> candle_core::Result<()> {
    use std::{collections::HashMap, hash::Hash};

    use candle_core::{DType, Device, Result, Tensor};
    use candle_lora::{
        loraembed::LoraEmbeddingConfig, EmbeddingLayerLike, Lora, NewLayers, SelectedLayers,
    };
    use candle_nn::{init, Embedding, Module, VarMap};

    #[derive(PartialEq, Eq, Hash)]
    enum ModelLayers {
        Embed,
    }

    #[derive(Debug)]
    struct Model {
        embed: Box<dyn EmbeddingLayerLike>,
    }

    impl Module for Model {
        fn forward(&self, input: &Tensor) -> Result<Tensor> {
            self.embed.forward(input)
        }
    }

    impl Model {
        /// Replace the plain embedding with its LoRA-wrapped counterpart.
        fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
            for (name, embed) in new.embed {
                match name {
                    ModelLayers::Embed => self.embed = Box::new(embed),
                }
            }
        }
    }

    let device = Device::Cpu;
    let dtype = DType::F32;

    let in_size = 10;
    let hidden_size = 3;

    // Create the model. `VarMap::get` creates (and here zero-initializes) the
    // named variable since it does not exist yet.
    let map = VarMap::new();
    let embed_weight = map.get(
        (in_size, hidden_size),
        "embed.weight",
        init::ZERO,
        dtype,
        &device,
    )?;

    let mut model = Model {
        embed: Box::new(Embedding::new(embed_weight, hidden_size)),
    };

    // Embedding lookups take U32 token indices; (2, 4) = (batch, seq).
    let dummy_input = Tensor::zeros((2, 4), DType::U32, &device)?;

    // Test the model: with zero weights the output is all zeros of shape
    // (2, 4, hidden_size).
    let output = model.forward(&dummy_input).unwrap();
    println!("Output: {output:?}");

    // Select the layers we want to convert; only the embedding map is populated.
    let linear_layers = HashMap::new();
    let conv1d_layers = HashMap::new();
    let conv2d_layers = HashMap::new();
    let mut embed_layers = HashMap::new();
    embed_layers.insert(ModelLayers::Embed, &*model.embed);
    let selected = SelectedLayers {
        linear: linear_layers,
        linear_config: None,
        conv1d: conv1d_layers,
        conv1d_config: None,
        conv2d: conv2d_layers,
        conv2d_config: None,
        embed: embed_layers,
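        // `LoraEmbeddingConfig::default` presumably just records the device,
        // dtype, and embedding dimensions needed to shape the low-rank factors,
        // leaving any remaining LoRA hyperparameters at their defaults.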
        embed_config: Some(LoraEmbeddingConfig::default(
            &device,
            dtype,
            in_size,
            hidden_size,
        )),
    };

    // Create new LoRA layers from our layers
    let new_layers = Lora::convert_model(selected);

    // Swap the LoRA layers into the model
    model.insert_new(new_layers);

    // Test the model again: the LoRA-wrapped embedding must be a drop-in
    // replacement, so the output shape is unchanged.
    let lora_output = model.forward(&dummy_input).unwrap();
    println!("LoRA Output: {lora_output:?}");

    assert_eq!(lora_output.shape(), output.shape());

    Ok(())
}
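
For context, here is a minimal sketch of the forward pass a LoRA-wrapped embedding performs. This is illustrative only, not candle-lora's implementation: the function `lora_embed_forward`, its parameter names, and the `(vocab, rank)` / `(rank, hidden)` layout of the low-rank factors are assumptions; only `Embedding`, `Module`, and the tensor ops come from candle itself.

use candle_core::{Result, Tensor};
use candle_nn::{Embedding, Module};

/// Hypothetical sketch of a LoRA embedding forward pass. The low-rank update
/// is delta_W = A @ B, and looking up row x of delta_W is the same as looking
/// up row x of A and projecting it through B.
fn lora_embed_forward(
    base: &Embedding, // frozen table, (vocab, hidden)
    a: &Embedding,    // low-rank factor A used as a lookup table, (vocab, rank)
    b: &Tensor,       // low-rank factor B, (rank, hidden)
    scale: f64,       // LoRA scaling, typically alpha / rank
    input: &Tensor,   // U32 token indices, e.g. shape (batch, seq)
) -> Result<Tensor> {
    let frozen = base.forward(input)?;          // (batch, seq, hidden)
    let low_rank = a.forward(input)?;           // (batch, seq, rank)
    let update = low_rank.broadcast_matmul(b)?; // (batch, seq, hidden)
    frozen + (update * scale)?
}

In the test itself the hyperparameters come from `LoraEmbeddingConfig::default`, and the final `assert_eq!` confirms the wrapped layer is a drop-in replacement. The test can be run with `cargo test --test embed`.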
