
Commit

Current WIP state push
EricLBuehler committed Sep 10, 2023
1 parent 4a045c2 commit 0e96fe5
Showing 3 changed files with 37 additions and 30 deletions.
src/frozenconv.rs (4 changes: 2 additions & 2 deletions)
@@ -107,7 +107,7 @@ impl FrozenConv2d {

 impl Module for FrozenConv2d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv1d(
+        let x = x.conv2d(
             &self.weight,
             self.config.padding,
             self.config.stride,
@@ -118,7 +118,7 @@ impl Module for FrozenConv2d {
             None => Ok(x),
             Some(bias) => {
                 let b = bias.dims1()?;
-                let bias = bias.reshape((1, b, 1))?;
+                let bias = bias.reshape((1, b, 1, 1))?;
                 Ok(x.broadcast_add(&bias)?)
             }
         }
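Note on the reshape above: a conv2d output is NCHW, so a per-channel bias of shape (C,) needs singleton batch and spatial dimensions before it can be broadcast-added; the old (1, b, 1) shape only matched the 3D conv1d output. A minimal sketch of that broadcasting step, with illustrative shapes that are not taken from this repository:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Stand-in for a conv2d output in NCHW layout: batch=1, channels=4, 8x8 spatial.
        let x = Tensor::zeros((1, 4, 8, 8), DType::F32, &device)?;
        // Per-channel bias of shape (4,).
        let bias = Tensor::ones(4, DType::F32, &device)?;
        // Give the bias singleton batch and spatial dims so it broadcasts over N, H, W.
        let c = bias.dims1()?;
        let bias = bias.reshape((1, c, 1, 1))?;
        let y = x.broadcast_add(&bias)?;
        assert_eq!(y.dims(), &[1, 4, 8, 8]);
        Ok(())
    }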
src/loraconv2d.rs (27 changes: 20 additions & 7 deletions)
@@ -83,17 +83,30 @@ impl LoraConv2D {
 impl Module for LoraConv2D {
     fn forward(&self, input: &Tensor) -> Result<Tensor> {
         if let Some(scale) = self.scale {
-            input.conv1d(
-                &self
-                    .b
-                    .matmul(&self.a)?
-                    .reshape(self.old.weight().shape())?
-                    .mul(scale)?,
+            let x = input;
+            let bias = self.bias();
+            let weight = &self
+                .b
+                .matmul(&self.a)?
+                .reshape(self.old.weight().shape())?
+                .mul(scale)?;
+
+            let x = x.conv2d(
+                weight,
                 self.config().padding,
                 self.config().stride,
                 self.config().dilation,
                 self.config().groups,
-            )
+            )?;
+
+            match &bias {
+                None => Ok(x),
+                Some(bias) => {
+                    let b = bias.dims1()?;
+                    let bias = bias.reshape((1, b, 1, 1))?;
+                    Ok(x.broadcast_add(&bias)?)
+                }
+            }
         } else {
             self.old.forward(input)
         }
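For context on the weight computation above: the LoRA update materializes a full conv kernel from two low-rank factors, where b.matmul(a) yields an (out_channels, in_channels * k * k) matrix that is reshaped to the frozen kernel's 4D shape and scaled. A minimal sketch of that shape math under assumed dimensions (rank 2, 8 output channels, 4 input channels, 3x3 kernel), using affine for the scaling rather than the commit's mul(scale) call:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Hypothetical LoRA factors for a kernel of shape (out=8, in=4, 3, 3) with rank 2.
        let b = Tensor::zeros((8, 2), DType::F32, &device)?; // (out_channels, rank)
        let a = Tensor::zeros((2, 4 * 3 * 3), DType::F32, &device)?; // (rank, in_channels * k * k)
        let scale = 0.5;
        // b.matmul(a) is (out_channels, in_channels * k * k); reshape to the 4D kernel layout and scale.
        let delta = b.matmul(&a)?.reshape((8, 4, 3, 3))?.affine(scale, 0.0)?;
        assert_eq!(delta.dims(), &[8, 4, 3, 3]);
        Ok(())
    }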
src/main.rs (36 changes: 15 additions & 21 deletions)
@@ -2,9 +2,9 @@ use std::{collections::HashMap, hash::Hash};

 use candle_core::{DType, Device, Result, Tensor};
 use candle_lora::{
-    loraconv1d::LoraConv1DConfig, Conv1DWithWB, Conv1dLayerLike, Lora, NewLayers, SelectedLayers,
+    loraconv2d::LoraConv2DConfig, Conv2DWithWB, Conv2dLayerLike, Lora, NewLayers, SelectedLayers,
 };
-use candle_nn::{init, Conv1d, Conv1dConfig, Module, VarMap};
+use candle_nn::{init, Conv2d, Conv2dConfig, Module, VarMap};

 #[derive(PartialEq, Eq, Hash)]
 enum ModelLayers {
@@ -13,7 +13,7 @@ enum ModelLayers {

 #[derive(Debug)]
 struct Model {
-    conv: Box<dyn Conv1dLayerLike>,
+    conv: Box<dyn Conv2dLayerLike>,
 }

 impl Module for Model {
@@ -24,7 +24,7 @@ impl Module for Model {

 impl Model {
     fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-        for (name, conv) in new.conv1d {
+        for (name, conv) in new.conv2d {
             match name {
                 ModelLayers::Conv => self.conv = Box::new(conv),
             }
@@ -39,25 +39,19 @@ fn main() -> Result<()> {
     //Create the model
     let map = VarMap::new();
     let conv_weight = map.get(
-        (1, 10, 10),
+        (1, 10, 10, 3),
         "conv.weight",
         init::DEFAULT_KAIMING_NORMAL,
         dtype,
         &device,
     )?;
-    let conv_bias = map.get(
-        10,
-        "conv.bias",
-        init::DEFAULT_KAIMING_NORMAL,
-        dtype,
-        &device,
-    )?;
+    let conv_bias = map.get(3, "conv.bias", init::DEFAULT_KAIMING_NORMAL, dtype, &device)?;

-    let conv = Conv1DWithWB {
-        this: Conv1d::new(
+    let conv = Conv2DWithWB {
+        this: Conv2d::new(
             conv_weight.clone(),
             Some(conv_bias.clone()),
-            Conv1dConfig::default(),
+            Conv2dConfig::default(),
         ),
         weights: conv_weight,
         bias: Some(conv_bias),
@@ -67,24 +61,24 @@
         conv: Box::new(conv),
     };

-    let dummy_image = Tensor::zeros((1, 10, 10), DType::F32, &device)?;
+    let dummy_image = Tensor::zeros((1, 10, 10, 3), DType::F32, &device)?;

     //Test the model
     let output = model.forward(&dummy_image).unwrap();
     println!("Output: {output:?}");

     //Select layers we want to convert
     let linear_layers = HashMap::new();
-    let mut conv1d_layers = HashMap::new();
-    let conv2d_layers = HashMap::new();
-    conv1d_layers.insert(ModelLayers::Conv, &*model.conv);
+    let conv1d_layers = HashMap::new();
+    let mut conv2d_layers = HashMap::new();
+    conv2d_layers.insert(ModelLayers::Conv, &*model.conv);
     let selected = SelectedLayers {
         linear: linear_layers,
         linear_config: None,
         conv1d: conv1d_layers,
-        conv1d_config: Some(LoraConv1DConfig::default(&device, dtype, 1, 10, 10)),
+        conv1d_config: None,
         conv2d: conv2d_layers,
-        conv2d_config: None,
+        conv2d_config: Some(LoraConv2DConfig::default(&device, dtype, 1, 10, 10)),
     };

     //Create new LoRA layers from our layers
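As background for the new shapes in this example: candle's Conv2d takes a kernel laid out as (out_channels, in_channels / groups, k_h, k_w), an optional bias with one value per output channel, and NCHW input. A small self-contained sketch with made-up shapes, not the ones used in the WIP example above:

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Conv2d, Conv2dConfig, Module};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Kernel: 2 output channels, 3 input channels, 3x3 spatial.
        let weight = Tensor::zeros((2, 3, 3, 3), DType::F32, &device)?;
        // One bias value per output channel.
        let bias = Tensor::zeros(2, DType::F32, &device)?;
        let conv = Conv2d::new(weight, Some(bias), Conv2dConfig::default());
        // NCHW input: batch=1, 3 channels, 10x10 image.
        let image = Tensor::zeros((1, 3, 10, 10), DType::F32, &device)?;
        let out = conv.forward(&image)?;
        // With a 3x3 kernel and no padding, 10x10 shrinks to 8x8.
        assert_eq!(out.dims(), &[1, 2, 8, 8]);
        Ok(())
    }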
