Use new methods to remove indirection
EricLBuehler committed Sep 11, 2023
1 parent 4fe27dd commit 55a5bbe
Showing 3 changed files with 18 additions and 54 deletions.
42 changes: 8 additions & 34 deletions src/lib.rs
@@ -106,28 +106,15 @@ pub trait Conv1dLayerLike: Module {
     fn config(&self) -> &Conv1dConfig;
 }
 
-#[derive(Debug)]
-pub struct Conv1dWithWB {
-    pub layer: Conv1d,
-    pub weights: Tensor,
-    pub bias: Option<Tensor>,
-}
-
-impl Module for Conv1dWithWB {
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        self.layer.forward(xs)
-    }
-}
-
-impl Conv1dLayerLike for Conv1dWithWB {
+impl Conv1dLayerLike for Conv1d {
     fn config(&self) -> &Conv1dConfig {
-        self.layer.config()
+        self.config()
     }
     fn weight(&self) -> &Tensor {
-        &self.weights
+        self.weight()
     }
     fn bias(&self) -> Option<&Tensor> {
-        self.bias.as_ref()
+        self.bias()
     }
 }

@@ -138,28 +125,15 @@ pub trait Conv2dLayerLike: Module {
     fn config(&self) -> &Conv2dConfig;
 }
 
-#[derive(Debug)]
-pub struct Conv2dWithWB {
-    pub layer: Conv2d,
-    pub weights: Tensor,
-    pub bias: Option<Tensor>,
-}
-
-impl Module for Conv2dWithWB {
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        self.layer.forward(xs)
-    }
-}
-
-impl Conv2dLayerLike for Conv2dWithWB {
+impl Conv2dLayerLike for Conv2d {
     fn config(&self) -> &Conv2dConfig {
-        self.layer.config()
+        self.config()
     }
     fn weight(&self) -> &Tensor {
-        &self.weights
+        self.weight()
     }
     fn bias(&self) -> Option<&Tensor> {
-        self.bias.as_ref()
+        self.bias()
     }
 }

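With the wrapper structs gone, a plain candle_nn::Conv1d can be passed anywhere a Conv1dLayerLike is expected, since the trait now delegates to Conv1d's own config(), weight(), and bias() accessors. A minimal sketch of the new usage follows; the demo function, tensor shapes, and variable names are illustrative, not from the repository:

use candle_core::{DType, Device, Tensor};
use candle_lora::Conv1dLayerLike;
use candle_nn::{Conv1d, Conv1dConfig};

fn demo() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Conv1d weights have shape (out_channels, in_channels, kernel_size).
    let weight = Tensor::zeros((4, 2, 3), DType::F32, &device)?;
    let bias = Tensor::zeros(4, DType::F32, &device)?;

    // No Conv1dWithWB wrapper needed: the Conv1d itself now implements
    // Conv1dLayerLike, so it can be boxed as the trait object directly.
    let conv = Conv1d::new(weight, Some(bias), Conv1dConfig::default());
    let layer: Box<dyn Conv1dLayerLike> = Box::new(conv);
    assert_eq!(layer.weight().dims(), &[4, 2, 3]);
    Ok(())
}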
15 changes: 4 additions & 11 deletions tests/conv1d.rs
@@ -4,8 +4,7 @@ fn conv1d() -> candle_core::Result<()> {
 
     use candle_core::{DType, Device, Result, Tensor};
     use candle_lora::{
-        loraconv1d::LoraConv1dConfig, Conv1dLayerLike, Conv1dWithWB, Lora, NewLayers,
-        SelectedLayers,
+        loraconv1d::LoraConv1dConfig, Conv1dLayerLike, Lora, NewLayers, SelectedLayers,
     };
     use candle_nn::{init, Conv1d, Conv1dConfig, Module, VarMap};

@@ -55,18 +54,12 @@ fn conv1d() -> candle_core::Result<()> {
         &device,
     )?;
 
-    let conv = Conv1dWithWB {
-        layer: Conv1d::new(
+    let mut model = Model {
+        conv: Box::new(Conv1d::new(
             conv_weight.clone(),
             Some(conv_bias.clone()),
             Conv1dConfig::default(),
-        ),
-        weights: conv_weight,
-        bias: Some(conv_bias),
-    };
-
-    let mut model = Model {
-        conv: Box::new(conv),
+        )),
     };
 
     let dummy_image = Tensor::zeros((1, 10, 10), DType::F32, &device)?;
15 changes: 6 additions & 9 deletions tests/conv2d.rs
@@ -4,8 +4,7 @@ fn conv2d() -> candle_core::Result<()> {
 
     use candle_core::{DType, Device, Result, Tensor};
     use candle_lora::{
-        loraconv2d::LoraConv2dConfig, Conv2dLayerLike, Conv2dWithWB, Lora, NewLayers,
-        SelectedLayers,
+        loraconv2d::LoraConv2dConfig, Conv2dLayerLike, Lora, NewLayers, SelectedLayers,
     };
     use candle_nn::{init, Conv2d, Conv2dConfig, Module, VarMap};

@@ -66,14 +65,12 @@ fn conv2d() -> candle_core::Result<()> {
         &device,
     )?;
 
-    let conv = Conv2dWithWB {
-        layer: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
-        weights: conv_weight,
-        bias: Some(conv_bias),
-    };
-
     let mut model = Model {
-        conv: Box::new(conv),
+        conv: Box::new(Conv2d::new(
+            conv_weight.clone(),
+            Some(conv_bias.clone()),
+            cfg,
+        )),
     };
 
     let shape = [2, in_channels, 20, 20]; //(BS, K, X, Y)
