Commit

Rework configs
EricLBuehler committed Sep 12, 2023
1 parent dc006aa commit 99eb4cb
Showing 15 changed files with 177 additions and 270 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -22,7 +22,7 @@ generate new layers that can easily be swapped out without forcing you to redefi
use std::{collections::HashMap, hash::Hash};

use candle_core::{DType, Device, Result, Tensor};
use candle_lora::{LinearLayerLike, Lora, LoraLinearConfigBuilder, NewLayers, SelectedLayers};
use candle_lora::{LinearLayerLike, Lora, LoraLinearConfig, NewLayers, SelectedLayers, LoraConfig};
use candle_nn::{init, Linear, Module, VarMap};

#[derive(PartialEq, Eq, Hash)]
@@ -83,7 +83,7 @@ fn main() -> candle_core::Result<()> {
let embed_layers = HashMap::new();
let selected = SelectedLayers {
linear: linear_layers,
linear_config: Some(LoraLinearConfigBuilder::default(&device, dtype, 10, 10).build()),
linear_config: Some(LoraLinearConfig::new(10, 10)),
conv1d: conv1d_layers,
conv1d_config: None,
conv2d: conv2d_layers,
@@ -92,8 +92,10 @@
embed_config: None,
};

let loraconfig = LoraConfig::new(1, 1., None, &device, dtype);

//Create new LoRA layers from our layers
let new_layers = Lora::convert_model(selected);
let new_layers = Lora::convert_model(selected, loraconfig);

//Custom methods to implement
model.insert_new(new_layers);
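Taken together, the README changes above reduce to two call-site differences: per-layer configs are now built with plain `new` constructors that describe only layer shapes, and a single `LoraConfig` carrying rank, alpha, dropout, device, and dtype is passed to `Lora::convert_model`. A minimal self-contained sketch assuming this commit's API; the CPU device, F32 dtype, `usize` key type, and empty layer maps are illustrative placeholders:

use std::collections::HashMap;

use candle_core::{DType, Device};
use candle_lora::{LinearLayerLike, Lora, LoraConfig, LoraLinearConfig, SelectedLayers};

fn main() {
    let device = Device::Cpu;

    // No layers are actually selected here; this only exercises the reworked
    // call signature. In the README example, this map holds the model's layers.
    let linear: HashMap<usize, &dyn LinearLayerLike> = HashMap::new();

    let selected = SelectedLayers {
        linear,
        // Per-layer configs now describe only layer shapes (10 in, 10 out).
        linear_config: Some(LoraLinearConfig::new(10, 10)),
        conv1d: HashMap::new(),
        conv1d_config: None,
        conv2d: HashMap::new(),
        conv2d_config: None,
        embed: HashMap::new(),
        embed_config: None,
    };

    // Rank, alpha, dropout, device, and dtype now live in one shared config.
    let loraconfig = LoraConfig::new(1, 1., None, &device, DType::F32);

    // The config is passed once and reused for every converted layer type.
    let _new_layers = Lora::convert_model(selected, loraconfig);
}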
58 changes: 45 additions & 13 deletions src/lib.rs
@@ -1,12 +1,12 @@
use candle_core::Error;
use candle_core::{DType, Device, Error};
#[doc = include_str!("../README.md")]
use candle_core::{Shape, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, Linear, Module};
use either::Either;
pub use loraconv1d::{LoraConv1d, LoraConv1dConfig, LoraConv1dConfigBuilder};
pub use loraconv2d::{LoraConv2d, LoraConv2dConfig, LoraConv2dConfigBuilder};
pub use loraembed::{LoraEmbedding, LoraEmbeddingConfig, LoraEmbeddingConfigBuilder};
pub use loralinear::{LoraLinear, LoraLinearConfig, LoraLinearConfigBuilder};
pub use loraconv1d::{LoraConv1d, LoraConv1dConfig};
pub use loraconv2d::{LoraConv2d, LoraConv2dConfig};
pub use loraembed::{LoraEmbedding, LoraEmbeddingConfig};
pub use loralinear::{LoraLinear, LoraLinearConfig};
use std::{collections::HashMap, hash::Hash};
use thiserror::Error;

@@ -24,6 +24,7 @@ impl Lora {
/// Convert the selected layers into their LoRA counterparts.
pub fn convert_model<T: Eq + PartialEq + Hash>(
selected: SelectedLayers<'_, T>,
config: LoraConfig,
) -> NewLayers<T> {
let mut new = NewLayers {
linear: HashMap::new(),
@@ -35,45 +36,76 @@
for (name, layer) in selected.linear {
new.linear.insert(
name,
LoraLinear::new(layer, selected.linear_config.as_ref().unwrap()).unwrap(),
LoraLinear::new(layer, selected.linear_config.as_ref().unwrap(), &config).unwrap(),
);
}

for (name, layer) in selected.conv1d {
new.conv1d.insert(
name,
LoraConv1d::new(layer, selected.conv1d_config.as_ref().unwrap()).unwrap(),
LoraConv1d::new(layer, selected.conv1d_config.as_ref().unwrap(), &config).unwrap(),
);
}

for (name, layer) in selected.conv2d {
new.conv2d.insert(
name,
LoraConv2d::new(layer, selected.conv2d_config.as_ref().unwrap()).unwrap(),
LoraConv2d::new(layer, selected.conv2d_config.as_ref().unwrap(), &config).unwrap(),
);
}

for (name, layer) in selected.embed {
new.embed.insert(
name,
LoraEmbedding::new(layer, selected.embed_config.as_ref().unwrap()).unwrap(),
LoraEmbedding::new(layer, selected.embed_config.as_ref().unwrap(), &config)
.unwrap(),
);
}

new
}
}

pub struct LoraConfig<'a> {
rank: usize,
alpha: f64,
dropout: Option<f32>,
device: &'a Device,
dtype: DType,
}

impl<'a> LoraConfig<'a> {
/// Create a new LoRA config.
/// - `rank`: The dimensions of low-rank matrices.
/// - `alpha`: Scaling factor for the LoRA signal.
/// - `dropout`: Dropout probability for the LoRA layers.
pub const fn new(
rank: usize,
alpha: f64,
dropout: Option<f32>,
device: &'a Device,
dtype: DType,
) -> Self {
Self {
rank,
alpha,
dropout,
device,
dtype,
}
}
}

/// Each configuration is applied to all layers of its respective type
pub struct SelectedLayers<'a, T: Eq + PartialEq + Hash> {
pub linear: HashMap<T, &'a dyn LinearLayerLike>,
pub linear_config: Option<LoraLinearConfig<'a>>,
pub linear_config: Option<LoraLinearConfig>,
pub conv1d: HashMap<T, &'a dyn Conv1dLayerLike>,
pub conv1d_config: Option<LoraConv1dConfig<'a>>,
pub conv1d_config: Option<LoraConv1dConfig>,
pub conv2d: HashMap<T, &'a dyn Conv2dLayerLike>,
pub conv2d_config: Option<LoraConv2dConfig<'a>>,
pub conv2d_config: Option<LoraConv2dConfig>,
pub embed: HashMap<T, &'a dyn EmbeddingLayerLike>,
pub embed_config: Option<LoraEmbeddingConfig<'a>>,
pub embed_config: Option<LoraEmbeddingConfig>,
}

/// New layers, after conversion
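The new `LoraConfig` above centralizes the hyperparameters and device/dtype that each per-layer config used to carry. A small sketch, assuming this commit's constructor, with the parameter meanings from the doc comment spelled out; the rank, alpha, and dropout values are arbitrary illustrations:

use candle_core::{DType, Device};
use candle_lora::LoraConfig;

fn main() {
    let device = Device::Cpu;

    // Arguments in the order documented above; device and dtype replace the
    // fields that every per-layer config used to carry.
    let config = LoraConfig::new(
        8,          // rank of the low-rank matrices
        16.,        // alpha scaling factor
        Some(0.05), // dropout probability for the LoRA layers
        &device,
        DType::F32,
    );

    // A single config like this is handed to `Lora::convert_model` and shared
    // by every converted linear, conv1d, conv2d, and embedding layer.
    let _ = config;
}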
82 changes: 21 additions & 61 deletions src/loraconv1d.rs
@@ -1,10 +1,12 @@
use std::ops::Mul;

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_core::{Module, Result, Tensor};
use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarMap};
use either::Either;

use crate::{frozenconv::FrozenConv1d, Conv1dLayerLike, Merge, MergeError, MergeErrorOrError};
use crate::{
frozenconv::FrozenConv1d, Conv1dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
};

#[derive(Debug)]
pub struct LoraConv1d {
@@ -17,75 +19,33 @@ pub struct LoraConv1d {
}

/// Configuration for LoraConv1d. Other configurations are inherited from the `Conv1d` struct.
pub struct LoraConv1dConfig<'a> {
rank: usize,
alpha: f64,
kernel_size: usize,
device: &'a Device,
dtype: DType,
pub struct LoraConv1dConfig {
in_channels: usize,
out_channels: usize,
dropout: Option<f32>,
}

/// Builder for LoraConv1dConfig. Call `build` to construct the config.
pub struct LoraConv1dConfigBuilder<'a> {
pub config: LoraConv1dConfig<'a>,
kernel_size: usize,
}

impl<'a> LoraConv1dConfigBuilder<'a> {
pub fn default(
device: &'a Device,
dtype: DType,
kernel_size: usize,
in_channels: usize,
out_channels: usize,
) -> Self {
LoraConv1dConfigBuilder {
config: LoraConv1dConfig {
rank: 1,
alpha: 1.,
kernel_size,
device,
dtype,
in_channels,
out_channels,
dropout: None,
},
impl LoraConv1dConfig {
pub fn new(kernel_size: usize, in_channels: usize, out_channels: usize) -> Self {
LoraConv1dConfig {
in_channels,
out_channels,
kernel_size,
}
}

/// Set the rank parameter
pub fn rank(mut self, rank: usize) -> Self {
self.config.rank = rank;
self
}

/// Set the alpha parameter
pub fn alpha(mut self, alpha: f64) -> Self {
self.config.alpha = alpha;
self
}

/// Set the dropout
pub fn dropout(mut self, prob: f32) -> Self {
self.config.dropout = Some(prob);
self
}

/// Construct the config
pub fn build(self) -> LoraConv1dConfig<'a> {
self.config
}
}

impl LoraConv1d {
pub fn new(old: &dyn Conv1dLayerLike, config: &LoraConv1dConfig) -> Result<Self> {
pub fn new(
old: &dyn Conv1dLayerLike,
conv_config: &LoraConv1dConfig,
config: &LoraConfig,
) -> Result<Self> {
let map = VarMap::new();
let a = map.get(
(
config.rank * config.kernel_size,
config.in_channels * config.kernel_size,
config.rank * conv_config.kernel_size,
conv_config.in_channels * conv_config.kernel_size,
),
"a.weight",
init::DEFAULT_KAIMING_NORMAL,
@@ -94,8 +94,8 @@ impl LoraConv1d {
)?;
let b = map.get(
(
config.out_channels / old.config().groups * config.kernel_size,
config.rank * config.kernel_size,
conv_config.out_channels / old.config().groups * conv_config.kernel_size,
config.rank * conv_config.kernel_size,
),
"b.weight",
init::ZERO,
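For orientation, the reworked `LoraConv1d::new` above sizes its low-rank matrices from the shared `LoraConfig` rank together with the kernel and channel sizes in `LoraConv1dConfig`. A tiny self-contained sketch of those shape computations; all concrete numbers are illustrative:

fn main() {
    // `rank` comes from the shared LoraConfig; `kernel_size`, `in_channels`,
    // and `out_channels` from LoraConv1dConfig; `groups` from the wrapped
    // layer's Conv1dConfig.
    let (rank, kernel_size, in_channels, out_channels, groups) = (2usize, 3, 8, 16, 1);

    let a_shape = (rank * kernel_size, in_channels * kernel_size); // "a.weight"
    let b_shape = (out_channels / groups * kernel_size, rank * kernel_size); // "b.weight"

    assert_eq!(a_shape, (6, 24));
    assert_eq!(b_shape, (48, 6));
}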
78 changes: 22 additions & 56 deletions src/loraconv2d.rs
@@ -1,10 +1,12 @@
use std::ops::Mul;

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_core::{Module, Result, Tensor};
use candle_nn::{init, Conv2d, Conv2dConfig, Dropout, VarMap};
use either::Either;

use crate::{frozenconv::FrozenConv2d, Conv2dLayerLike, Merge, MergeError, MergeErrorOrError};
use crate::{
frozenconv::FrozenConv2d, Conv2dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
};

#[derive(Debug)]
pub struct LoraConv2d {
@@ -17,72 +19,31 @@ pub struct LoraConv2d {
}

/// Configuration for LoraConv2d. Other configurations are inherited from the `Conv2d` struct.
pub struct LoraConv2dConfig<'a> {
rank: usize,
alpha: f64,
device: &'a Device,
dtype: DType,
pub struct LoraConv2dConfig {
in_channels: usize,
out_channels: usize,
dropout: Option<f32>,
}

/// Builder for LoraConv2dConfig. Call `build` to construct the config.
pub struct LoraConv2dConfigBuilder<'a> {
pub config: LoraConv2dConfig<'a>,
}

impl<'a> LoraConv2dConfigBuilder<'a> {
pub fn default(
device: &'a Device,
dtype: DType,
in_channels: usize,
out_channels: usize,
) -> Self {
LoraConv2dConfigBuilder {
config: LoraConv2dConfig {
rank: 1,
alpha: 1.,
device,
dtype,
in_channels,
out_channels,
dropout: None,
},
impl LoraConv2dConfig {
pub fn new(in_channels: usize, out_channels: usize) -> Self {
LoraConv2dConfig {
in_channels,
out_channels,
}
}

/// Set the rank parameter
pub fn rank(mut self, rank: usize) -> Self {
self.config.rank = rank;
self
}

/// Set the alpha parameter
pub fn alpha(mut self, alpha: f64) -> Self {
self.config.alpha = alpha;
self
}

/// Set the dropout
pub fn dropout(mut self, prob: f32) -> Self {
self.config.dropout = Some(prob);
self
}

/// Construct the config
pub fn build(self) -> LoraConv2dConfig<'a> {
self.config
}
}

impl LoraConv2d {
pub fn new(old: &dyn Conv2dLayerLike, config: &LoraConv2dConfig) -> Result<Self> {
pub fn new(
old: &dyn Conv2dLayerLike,
conv_config: &LoraConv2dConfig,
config: &LoraConfig,
) -> Result<Self> {
let map = VarMap::new();
let a = map.get(
(
config.rank,
config.in_channels / old.config().groups,
conv_config.in_channels / old.config().groups,
old.weight().dim(2).unwrap(),
old.weight().dim(3).unwrap(),
),
@@ -92,7 +53,12 @@ impl LoraConv2d {
config.device,
)?;
let b = map.get(
(config.out_channels, config.rank / old.config().groups, 1, 1),
(
conv_config.out_channels,
config.rank / old.config().groups,
1,
1,
),
"b.weight",
init::ZERO,
config.dtype,
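As with the conv1d rework, `LoraConv2d::new` above derives its tensor shapes from the shared `LoraConfig` rank, the channel counts in `LoraConv2dConfig`, and the wrapped layer's groups and kernel dimensions, with `a.weight` now a 4-D tensor. A tiny self-contained sketch of the shapes requested by the `map.get` calls above; all numbers are illustrative:

fn main() {
    // `rank` comes from the shared LoraConfig; channel counts from
    // LoraConv2dConfig; `groups` and the kernel height/width from the wrapped
    // Conv2d layer.
    let (rank, in_channels, out_channels, groups, kh, kw) = (2usize, 8, 16, 1, 3, 3);

    let a_shape = (rank, in_channels / groups, kh, kw); // "a.weight"
    let b_shape = (out_channels, rank / groups, 1, 1); // "b.weight"

    assert_eq!(a_shape, (2, 8, 3, 3));
    assert_eq!(b_shape, (16, 2, 1, 1));
}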
(Diffs for the remaining 11 changed files are not shown here.)
