Implement Embedding for LoRA, fix Linear impl
EricLBuehler committed Sep 11, 2023
1 parent f58ca18 commit ad03a78
Showing 9 changed files with 92 additions and 108 deletions.
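A note on what the embedding adapter computes: LoRA keeps the original layer frozen and adds a low-rank update. For an embedding table of shape (vocab, dim), the adapter holds A of shape (rank, vocab) and B of shape (dim, rank), and each looked-up row gets a scaled (A-row times B-transpose) term added to it. The following standalone sketch shows that idea in candle; it mirrors the new LoraEmbedding::forward below, but the names, sizes, and initializations are illustrative assumptions, not this crate's API.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};

// Illustrative LoRA-for-embedding sketch: out = E[ids] + scale * (A^T[ids] @ B^T),
// with base table E: (vocab, dim), A: (rank, vocab), B: (dim, rank).
fn main() -> Result<()> {
    let device = Device::Cpu;
    let (vocab, dim, rank) = (10usize, 4usize, 2usize);
    let scale = 1.0; // alpha / rank

    let base = Embedding::new(Tensor::randn(0f32, 1f32, (vocab, dim), &device)?, dim);
    let a = Tensor::randn(0f32, 1f32, (rank, vocab), &device)?;
    // B starts at zero, so the adapted output initially equals the base output.
    let b = Tensor::zeros((dim, rank), DType::F32, &device)?;

    let ids = Tensor::new(&[1u32, 3, 7], &device)?;
    // Reuse Embedding to gather rows of A^T, i.e. one (rank,) vector per token id.
    let a_rows = Embedding::new(a.transpose(0, 1)?.contiguous()?, rank).forward(&ids)?;
    let delta = a_rows.broadcast_matmul(&b.transpose(0, 1)?)?; // (tokens, dim)
    let out = (base.forward(&ids)? + (delta * scale)?)?;
    assert_eq!(out.dims2()?, (3, dim));
    Ok(())
}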
README.md (3 changes: 3 additions & 0 deletions)
@@ -79,13 +79,16 @@ fn main() -> candle_core::Result<()> {
     linear_layers.insert(ModelLayers::Layer, &*model.layer);
     let conv1d_layers = HashMap::new();
     let conv2d_layers = HashMap::new();
+    let embed_layers = HashMap::new();
     let selected = SelectedLayers {
         linear: linear_layers,
         linear_config: Some(LoraLinearConfig::default(&device, dtype, 10, 10)),
         conv1d: conv1d_layers,
         conv1d_config: None,
         conv2d: conv2d_layers,
         conv2d_config: None,
+        embed: embed_layers,
+        embed_config: None,
     };
 
     //Create new LoRA layers from our layers
src/frozenembed.rs (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ impl FrozenEmbedding {
         })
     }
 
-    pub(crate) fn new_from_linear(old: &dyn EmbeddingLayerLike) -> Result<Self> {
+    pub(crate) fn new_from_embed(old: &dyn EmbeddingLayerLike) -> Result<Self> {
         Self::new(old.embeddings(), old.hidden_size())
     }
 }
src/lib.rs (32 changes: 16 additions & 16 deletions)
@@ -4,6 +4,7 @@ use candle_core::{Shape, Tensor};
 use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, Linear, Module};
 use loraconv1d::{LoraConv1d, LoraConv1dConfig};
 use loraconv2d::{LoraConv2d, LoraConv2dConfig};
+use loraembed::{LoraEmbedding, LoraEmbeddingConfig};
 use loralinear::{LoraLinear, LoraLinearConfig};
 use std::{collections::HashMap, hash::Hash};
 
@@ -12,8 +13,8 @@ mod frozenembed;
 mod frozenlinear;
 pub mod loraconv1d;
 pub mod loraconv2d;
-pub mod loralinear;
 pub mod loraembed;
+pub mod loralinear;
 
 pub struct Lora;
 
@@ -25,6 +26,7 @@ impl Lora {
             linear: HashMap::new(),
             conv1d: HashMap::new(),
             conv2d: HashMap::new(),
+            embed: HashMap::new(),
         };
 
         for (name, layer) in selected.linear {
@@ -48,6 +50,13 @@ impl Lora {
             );
         }
 
+        for (name, layer) in selected.embed {
+            new.embed.insert(
+                name,
+                LoraEmbedding::new(layer, selected.embed_config.as_ref().unwrap()).unwrap(),
+            );
+        }
+
         new
     }
 }
@@ -59,12 +68,15 @@ pub struct SelectedLayers<'a, T: Eq + PartialEq + Hash> {
     pub conv1d_config: Option<LoraConv1dConfig<'a>>,
     pub conv2d: HashMap<T, &'a dyn Conv2dLayerLike>,
     pub conv2d_config: Option<LoraConv2dConfig<'a>>,
+    pub embed: HashMap<T, &'a dyn EmbeddingLayerLike>,
+    pub embed_config: Option<LoraEmbeddingConfig<'a>>,
 }
 
 pub struct NewLayers<T: Eq + PartialEq + Hash> {
     pub linear: HashMap<T, LoraLinear>,
     pub conv1d: HashMap<T, LoraConv1d>,
     pub conv2d: HashMap<T, LoraConv2d>,
+    pub embed: HashMap<T, LoraEmbedding>,
 }
 
 pub trait LinearLayerLike: Module {
@@ -152,23 +164,11 @@ pub trait EmbeddingLayerLike: Module {
     fn hidden_size(&self) -> usize;
 }
 
-#[derive(Debug)]
-pub struct EmbeddingWithSize {
-    pub layer: Embedding,
-    pub hidden_size: usize,
-}
-
-impl Module for EmbeddingWithSize {
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        self.layer.forward(xs)
-    }
-}
-
-impl EmbeddingLayerLike for EmbeddingWithSize {
+impl EmbeddingLayerLike for Embedding {
     fn embeddings(&self) -> &Tensor {
-        self.layer.embeddings()
+        self.embeddings()
     }
     fn hidden_size(&self) -> usize {
-        self.hidden_size
+        self.embeddings().dim(1).unwrap() //Reason: 2nd dim is always the hidden
     }
 }
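The new impl above reads the hidden size straight off the weight tensor: a candle_nn::Embedding weight is laid out as (num_embeddings, hidden_size), so dim(1) is the hidden size. A quick standalone check with illustrative sizes:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::Embedding;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative sizes: a 10-token vocabulary with hidden size 3 -> weight shape (10, 3).
    let weight = Tensor::zeros((10, 3), DType::F32, &device)?;
    let embed = Embedding::new(weight, 3);
    // dim(1) of the weight is the hidden size, which is what the impl above returns.
    assert_eq!(embed.embeddings().dim(1)?, 3);
    Ok(())
}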
src/loraembed.rs (81 changes: 36 additions & 45 deletions)
@@ -1,106 +1,97 @@
-use std::ops::Mul;
+use candle_core::{DType, Device, Module, Result, Tensor};
+use candle_nn::{init, Embedding, VarMap};
 
-use candle_core::{DType, Device, Module, Result, Shape, Tensor};
-use candle_nn::{init, Dropout, VarMap};
-
-use crate::{frozenlinear::FrozenLinear, LinearLayerLike};
+use crate::{frozenembed::FrozenEmbedding, EmbeddingLayerLike};
 
 #[derive(Debug)]
-pub struct LoraLinear {
-    old: FrozenLinear,
+pub struct LoraEmbedding {
+    old: FrozenEmbedding,
     a: Tensor,
     b: Tensor,
     scale: Option<f64>,
-    dropout: Option<Dropout>,
 }
 
-pub struct LoraLinearConfig<'a> {
+pub struct LoraEmbeddingConfig<'a> {
     pub rank: usize,
     pub alpha: f64,
-    pub dropout: Option<f32>,
     pub device: &'a Device,
     pub dtype: DType,
-    pub in_features: usize,
-    pub out_features: usize,
+    pub num_embeddings: usize,
+    pub embedding_dim: usize,
 }

-impl<'a> LoraLinearConfig<'a> {
+impl<'a> LoraEmbeddingConfig<'a> {
     pub fn default(
         device: &'a Device,
         dtype: DType,
-        in_features: usize,
-        out_features: usize,
+        num_embeddings: usize,
+        embedding_dim: usize,
     ) -> Self {
-        LoraLinearConfig {
+        LoraEmbeddingConfig {
             rank: 1,
             alpha: 1.,
-            dropout: Some(0.),
             device,
             dtype,
-            in_features,
-            out_features,
+            num_embeddings,
+            embedding_dim,
         }
     }
 }
 
-impl LoraLinear {
-    pub fn new(old: &dyn LinearLayerLike, config: &LoraLinearConfig) -> Result<Self> {
+impl LoraEmbedding {
+    pub fn new(old: &dyn EmbeddingLayerLike, config: &LoraEmbeddingConfig) -> Result<Self> {
         let map = VarMap::new();
         let a = map.get(
-            (config.rank, config.in_features),
+            (config.rank, config.num_embeddings),
             "a.weight",
-            init::DEFAULT_KAIMING_NORMAL,
+            init::ZERO,
             config.dtype,
             config.device,
         )?;
         let b = map.get(
-            (config.out_features, config.rank),
+            (config.embedding_dim, config.rank),
             "b.weight",
             init::ZERO,
             config.dtype,
             config.device,
         )?;
 
-        Ok(LoraLinear {
-            old: FrozenLinear::new_from_linear(old)?,
+        Ok(LoraEmbedding {
+            old: FrozenEmbedding::new_from_embed(old)?,
             a,
             b,
             scale: if config.rank > 0 {
                 Some(config.alpha / config.rank as f64)
             } else {
                 None
             },
-            dropout: config.dropout.map(Dropout::new),
         })
     }
 }

-impl Module for LoraLinear {
+impl Module for LoraEmbedding {
     fn forward(&self, input: &Tensor) -> Result<Tensor> {
-        //No fan_in_fan_out so no weight.transpose(0,1)
         let mut result = self.old.forward(input)?;
         if let Some(scale) = self.scale {
-            if self.dropout.is_some() {
-                result = (result + self.dropout.as_ref().unwrap().forward(input, true)?)?;
-            } else {
-                result = (result + input)?;
-            }
-            result = (&result+result.matmul(&self.a.transpose(0, 1)?)?)?;
-            result = (&result+result.matmul(&self.b.transpose(0, 1)?)?)?;
-            result = (&result+result.clone().mul(scale)?)?;
+            let weight = self.a.transpose(0, 1)?;
+            let weight = weight.reshape(weight.shape())?; //Get contiguous
+            let hidden = weight.dim(1)?;
+
+            let embed = Embedding::new(weight, hidden);
+            let after_a = embed.forward(input)?;
+
+            result = (result + after_a.broadcast_matmul(&self.b.transpose(0, 1)?)?)?;
+            result = (result * scale)?;
         }
         Ok(result)
     }
 }
 
-impl LinearLayerLike for LoraLinear {
-    fn bias(&self) -> Option<&Tensor> {
-        self.old.bias()
-    }
-    fn weight(&self) -> &Tensor {
-        self.old.weight()
+impl EmbeddingLayerLike for LoraEmbedding {
+    fn embeddings(&self) -> &Tensor {
+        self.old.embeddings()
     }
-    fn shape(&self) -> &Shape {
-        self.old.shape()
+    fn hidden_size(&self) -> usize {
+        self.old.hidden_size()
     }
 }
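One detail of the new forward pass: transpose(0, 1) only swaps strides, so the transposed a matrix is not contiguous in memory, and the reshape(weight.shape()) call (marked //Get contiguous) forces a contiguous copy before the matrix is used as an embedding table. A small sketch of that effect, assuming candle's is_contiguous and contiguous helpers behave as described:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let a = Tensor::zeros((2, 5), DType::F32, &device)?;

    let t = a.transpose(0, 1)?; // (5, 2) view with swapped strides over the same storage
    assert!(!t.is_contiguous());

    // Both calls below materialize a contiguous (5, 2) tensor; the reshape-to-the-same-shape
    // form is the one the diff above uses before building an Embedding from the weight.
    let via_reshape = t.reshape(t.shape())?;
    let via_contiguous = t.contiguous()?;
    assert!(via_reshape.is_contiguous() && via_contiguous.is_contiguous());
    Ok(())
}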
src/loralinear.rs (7 changes: 4 additions & 3 deletions)
@@ -85,9 +85,10 @@ impl Module for LoraLinear {
             } else {
                 result = (result + input)?;
             }
-            result = result.matmul(&self.a.transpose(0, 1)?)?;
-            result = result.matmul(&self.b.transpose(0, 1)?)?;
-            result = result.mul(scale)?;
+            result = result.broadcast_add(
+                &result.matmul(&self.b.broadcast_matmul(&self.a.matmul(&result)?)?)?,
+            )?;
+            result = result.broadcast_add(&result.clone().mul(scale)?)?;
         }
         Ok(result)
     }
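For reference, the textbook LoRA update for a linear layer is y = x·W0ᵀ + (α/r)·(x·Aᵀ)·Bᵀ, with A of shape (rank, in) and B of shape (out, rank). The sketch below implements that reference formula directly in candle; it is a hedged illustration and not necessarily the exact computation this commit lands on for LoraLinear.

use candle_core::{DType, Device, Result, Tensor};

// Reference LoRA linear: y = x @ W0^T + scale * (x @ A^T) @ B^T,
// with W0: (out, in), A: (rank, in), B: (out, rank), scale = alpha / rank.
fn lora_linear_forward(x: &Tensor, w0: &Tensor, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
    let base = x.broadcast_matmul(&w0.transpose(0, 1)?)?;
    let delta = x
        .broadcast_matmul(&a.transpose(0, 1)?)?
        .broadcast_matmul(&b.transpose(0, 1)?)?;
    base + (delta * scale)?
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (batch, in_f, out_f, rank) = (2usize, 4usize, 3usize, 2usize);
    let x = Tensor::randn(0f32, 1f32, (batch, in_f), &device)?;
    let w0 = Tensor::randn(0f32, 1f32, (out_f, in_f), &device)?;
    let a = Tensor::randn(0f32, 1f32, (rank, in_f), &device)?;
    // B starts at zero, as in standard LoRA initialization.
    let b = Tensor::zeros((out_f, rank), DType::F32, &device)?;
    let y = lora_linear_forward(&x, &w0, &a, &b, 1.0)?;
    assert_eq!(y.dims2()?, (batch, out_f));
    Ok(())
}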
src/main.rs (66 changes: 23 additions & 43 deletions)
@@ -2,31 +2,31 @@ use std::{collections::HashMap, hash::Hash};
 
 use candle_core::{DType, Device, Result, Tensor};
 use candle_lora::{
-    loraconv2d::LoraConv2dConfig, Conv2dLayerLike, Conv2dWithWB, Lora, NewLayers, SelectedLayers,
+    loraembed::LoraEmbeddingConfig, EmbeddingLayerLike, Lora, NewLayers, SelectedLayers,
 };
-use candle_nn::{init, Conv2d, Conv2dConfig, Module, VarMap};
+use candle_nn::{init, Embedding, Module, VarMap};
 
 #[derive(PartialEq, Eq, Hash)]
 enum ModelLayers {
-    Conv,
+    Embed,
 }
 
 #[derive(Debug)]
 struct Model {
-    conv: Box<dyn Conv2dLayerLike>,
+    embed: Box<dyn EmbeddingLayerLike>,
 }
 
 impl Module for Model {
     fn forward(&self, input: &Tensor) -> Result<Tensor> {
-        self.conv.forward(input)
+        self.embed.forward(input)
     }
 }
 
 impl Model {
     fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
-        for (name, conv) in new.conv2d {
+        for (name, conv) in new.embed {
             match name {
-                ModelLayers::Conv => self.conv = Box::new(conv),
+                ModelLayers::Embed => self.embed = Box::new(conv),
             }
         }
     }
@@ -36,46 +36,24 @@ fn main() -> Result<()> {
     let device = Device::Cpu;
     let dtype = DType::F32;
 
-    let out_channels = 10;
-    let in_channels = 10;
-    let kernel = 2;
-
-    let cfg = Conv2dConfig::default();
+    let in_size = 10;
+    let hidden_size = 3;
 
     //Create the model
     let map = VarMap::new();
-    let conv_weight = map.get(
-        (
-            out_channels,
-            in_channels / cfg.groups, //cfg.groups in this case are 1
-            kernel,
-            kernel,
-        ),
-        "conv.weight",
-        init::DEFAULT_KAIMING_NORMAL,
-        dtype,
-        &device,
-    )?;
-    let conv_bias = map.get(
-        out_channels,
-        "conv.bias",
-        init::DEFAULT_KAIMING_NORMAL,
+    let embed_weight = map.get(
+        (in_size, hidden_size),
+        "embed.weight",
+        init::ZERO,
         dtype,
         &device,
     )?;
 
-    let conv = Conv2dWithWB {
-        layer: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
-        weights: conv_weight,
-        bias: Some(conv_bias),
-    };
-
     let mut model = Model {
-        conv: Box::new(conv),
+        embed: Box::new(Embedding::new(embed_weight, hidden_size)),
    };
 
-    let shape = [2, in_channels, 20, 20]; //(BS, K, X, Y)
-    let dummy_image = Tensor::zeros(&shape, DType::F32, &device)?;
+    let dummy_image = Tensor::zeros((2, 4), DType::U32, &device)?;
 
     //Test the model
     let output = model.forward(&dummy_image).unwrap();
@@ -84,20 +62,22 @@
     //Select layers we want to convert
     let linear_layers = HashMap::new();
     let conv1d_layers = HashMap::new();
-    let mut conv2d_layers = HashMap::new();
-    conv2d_layers.insert(ModelLayers::Conv, &*model.conv);
+    let conv2d_layers = HashMap::new();
+    let mut embed_layers = HashMap::new();
+    embed_layers.insert(ModelLayers::Embed, &*model.embed);
     let selected = SelectedLayers {
         linear: linear_layers,
         linear_config: None,
         conv1d: conv1d_layers,
         conv1d_config: None,
         conv2d: conv2d_layers,
-        conv2d_config: Some(LoraConv2dConfig::default(
+        conv2d_config: None,
+        embed: embed_layers,
+        embed_config: Some(LoraEmbeddingConfig::default(
            &device,
            dtype,
-            kernel,
-            in_channels,
-            out_channels,
+            in_size,
+            hidden_size,
        )),
    };
 
The remaining 3 changed files are not shown in this view.