diff --git a/src/frozenconv.rs b/src/frozenconv.rs
index ca89ba1..27c15c5 100644
--- a/src/frozenconv.rs
+++ b/src/frozenconv.rs
@@ -107,7 +107,7 @@ impl FrozenConv2d {
 
 impl Module for FrozenConv2d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv1d(
+        let x = x.conv2d(
             &self.weight,
             self.config.padding,
             self.config.stride,
@@ -118,7 +118,7 @@ impl Module for FrozenConv2d {
             None => Ok(x),
             Some(bias) => {
                 let b = bias.dims1()?;
-                let bias = bias.reshape((1, b, 1))?;
+                let bias = bias.reshape((1, b, 1, 1))?;
                 Ok(x.broadcast_add(&bias)?)
             }
         }
diff --git a/src/loraconv2d.rs b/src/loraconv2d.rs
index a699107..27ea34a 100644
--- a/src/loraconv2d.rs
+++ b/src/loraconv2d.rs
@@ -83,17 +83,30 @@ impl LoraConv2D {
 impl Module for LoraConv2D {
     fn forward(&self, input: &Tensor) -> Result<Tensor> {
         if let Some(scale) = self.scale {
-            input.conv1d(
-                &self
-                    .b
-                    .matmul(&self.a)?
-                    .reshape(self.old.weight().shape())?
-                    .mul(scale)?,
+            let x = input;
+            let bias = self.bias();
+            let weight = &self
+                .b
+                .matmul(&self.a)?
+                .reshape(self.old.weight().shape())?
+                .mul(scale)?;
+
+            let x = x.conv2d(
+                weight,
                 self.config().padding,
                 self.config().stride,
                 self.config().dilation,
                 self.config().groups,
-            )
+            )?;
+
+            match &bias {
+                None => Ok(x),
+                Some(bias) => {
+                    let b = bias.dims1()?;
+                    let bias = bias.reshape((1, b, 1, 1))?;
+                    Ok(x.broadcast_add(&bias)?)
+                }
+            }
         } else {
             self.old.forward(input)
         }
diff --git a/src/main.rs b/src/main.rs
index 6040ef5..a9058d8 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,9 +2,9 @@ use std::{collections::HashMap, hash::Hash};
 
 use candle_core::{DType, Device, Result, Tensor};
 use candle_lora::{
-    loraconv1d::LoraConv1DConfig, Conv1DWithWB, Conv1dLayerLike, Lora, NewLayers, SelectedLayers,
+    loraconv2d::LoraConv2DConfig, Conv2DWithWB, Conv2dLayerLike, Lora, NewLayers, SelectedLayers,
 };
-use candle_nn::{init, Conv1d, Conv1dConfig, Module, VarMap};
+use candle_nn::{init, Conv2d, Conv2dConfig, Module, VarMap};
 
 #[derive(PartialEq, Eq, Hash)]
 enum ModelLayers {
@@ -13,7 +13,7 @@ enum ModelLayers {
 
 #[derive(Debug)]
 struct Model {
-    conv: Box<dyn Conv1dLayerLike>,
+    conv: Box<dyn Conv2dLayerLike>,
 }
 
 impl Module for Model {
@@ -24,7 +24,7 @@ impl Module for Model {
 
 impl Model {
     fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-        for (name, conv) in new.conv1d {
+        for (name, conv) in new.conv2d {
             match name {
                 ModelLayers::Conv => self.conv = Box::new(conv),
             }
@@ -39,25 +39,19 @@ fn main() -> Result<()> {
     //Create the model
     let map = VarMap::new();
     let conv_weight = map.get(
-        (1, 10, 10),
+        (1, 10, 10, 3),
         "conv.weight",
         init::DEFAULT_KAIMING_NORMAL,
         dtype,
         &device,
     )?;
-    let conv_bias = map.get(
-        10,
-        "conv.bias",
-        init::DEFAULT_KAIMING_NORMAL,
-        dtype,
-        &device,
-    )?;
+    let conv_bias = map.get(3, "conv.bias", init::DEFAULT_KAIMING_NORMAL, dtype, &device)?;
 
-    let conv = Conv1DWithWB {
-        this: Conv1d::new(
+    let conv = Conv2DWithWB {
+        this: Conv2d::new(
             conv_weight.clone(),
             Some(conv_bias.clone()),
-            Conv1dConfig::default(),
+            Conv2dConfig::default(),
         ),
         weights: conv_weight,
         bias: Some(conv_bias),
@@ -67,7 +61,7 @@ fn main() -> Result<()> {
         conv: Box::new(conv),
     };
 
-    let dummy_image = Tensor::zeros((1, 10, 10), DType::F32, &device)?;
+    let dummy_image = Tensor::zeros((1, 10, 10, 3), DType::F32, &device)?;
 
     //Test the model
     let output = model.forward(&dummy_image).unwrap();
@@ -75,16 +69,16 @@ fn main() -> Result<()> {
     //Select layers we want to convert
     let linear_layers = HashMap::new();
-    let mut conv1d_layers = HashMap::new();
-    let conv2d_layers = HashMap::new();
-    conv1d_layers.insert(ModelLayers::Conv, &*model.conv);
+    let conv1d_layers = HashMap::new();
+    let mut conv2d_layers = HashMap::new();
+    conv2d_layers.insert(ModelLayers::Conv, &*model.conv);
 
     let selected = SelectedLayers {
         linear: linear_layers,
         linear_config: None,
         conv1d: conv1d_layers,
-        conv1d_config: Some(LoraConv1DConfig::default(&device, dtype, 1, 10, 10)),
+        conv1d_config: None,
         conv2d: conv2d_layers,
-        conv2d_config: None,
+        conv2d_config: Some(LoraConv2DConfig::default(&device, dtype, 1, 10, 10)),
     };
 
     //Create new LoRA layers from our layers
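For reference, the forward-pass pattern this patch introduces in `LoraConv2D::forward` can be sketched in isolation: merge the low-rank factors `B·A` into a full conv kernel, run it through `conv2d`, and reshape the bias from `(out_c,)` to `(1, out_c, 1, 1)` so it broadcasts over the NCHW output. This is a minimal sketch against plain `candle_core`; the shapes, the rank, and the use of `affine` in place of the patch's `.mul(scale)?` are illustrative assumptions, not code from the crate.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // NCHW input: batch = 1, 4 input channels, 8x8 spatial.
    let x = Tensor::randn(0f32, 1f32, (1, 4, 8, 8), &device)?;

    // Hypothetical rank-2 LoRA factors whose product flattens to a
    // (out_c, in_c, k_h, k_w) = (2, 4, 3, 3) kernel.
    let b = Tensor::randn(0f32, 1f32, (2, 2), &device)?;
    let a = Tensor::randn(0f32, 1f32, (2, 4 * 3 * 3), &device)?;

    // Merge B·A into a conv kernel and apply the LoRA scale.
    let scale = 2.0;
    let weight = b.matmul(&a)?.reshape((2, 4, 3, 3))?.affine(scale, 0.0)?;

    // conv2d(kernel, padding, stride, dilation, groups), as in the patch.
    let y = x.conv2d(&weight, 1, 1, 1, 1)?;

    // Bias is stored as (out_c,); reshaping it to (1, out_c, 1, 1) lets
    // broadcast_add spread it over the batch and spatial dimensions.
    let bias = Tensor::ones(2, DType::F32, &device)?;
    let out_c = bias.dims1()?;
    let y = y.broadcast_add(&bias.reshape((1, out_c, 1, 1))?)?;

    assert_eq!(y.dims(), &[1, 2, 8, 8]);
    Ok(())
}
```

The four-dimensional reshape is also the point of the `frozenconv.rs` hunk: a `(1, b, 1)` bias only lines up with a conv1d output of shape `(batch, channels, length)`, whereas conv2d outputs are `(batch, channels, height, width)`.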