diff --git a/src/frozenembed.rs b/src/frozenembed.rs
index e69de29..782e44c 100644
--- a/src/frozenembed.rs
+++ b/src/frozenembed.rs
@@ -0,0 +1,42 @@
+use candle_core::{Result, Tensor};
+
+use crate::EmbeddingLayerLike;
+
+#[derive(Debug)]
+pub(crate) struct FrozenEmbedding {
+    embeddings: Tensor,
+    hidden_size: usize,
+}
+
+impl FrozenEmbedding {
+    pub(crate) fn new(embeddings: &Tensor, hidden_size: usize) -> Result<Self> {
+        Ok(Self {
+            embeddings: embeddings.detach()?,
+            hidden_size,
+        })
+    }
+
+    pub(crate) fn new_from_linear(old: &dyn EmbeddingLayerLike) -> Result<Self> {
+        Self::new(old.embeddings(), old.hidden_size())
+    }
+}
+
+impl crate::Module for FrozenEmbedding {
+    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+        let mut final_dims = indexes.dims().to_vec();
+        final_dims.push(self.hidden_size);
+        let indexes = indexes.flatten_all()?;
+        let values = self.embeddings.index_select(&indexes, 0)?;
+        let values = values.reshape(final_dims)?;
+        Ok(values)
+    }
+}
+
+impl EmbeddingLayerLike for FrozenEmbedding {
+    fn embeddings(&self) -> &Tensor {
+        &self.embeddings
+    }
+    fn hidden_size(&self) -> usize {
+        self.hidden_size
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 3b1470a..01043ac 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,15 +1,15 @@
 //According to https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
 #[doc = include_str!("../README.md")]
 use candle_core::{Shape, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Linear, Module};
+use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, Linear, Module};
 use loraconv1d::{LoraConv1d, LoraConv1dConfig};
 use loraconv2d::{LoraConv2d, LoraConv2dConfig};
 use loralinear::{LoraLinear, LoraLinearConfig};
 use std::{collections::HashMap, hash::Hash};
 
 mod frozenconv;
-mod frozenlinear;
 mod frozenembed;
+mod frozenlinear;
 pub mod loraconv1d;
 pub mod loraconv2d;
 pub mod loralinear;
@@ -92,20 +92,20 @@ pub trait Conv1dLayerLike: Module {
 
 #[derive(Debug)]
 pub struct Conv1dWithWB {
-    pub this: Conv1d,
+    pub layer: Conv1d,
     pub weights: Tensor,
     pub bias: Option<Tensor>,
 }
 
 impl Module for Conv1dWithWB {
     fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        self.this.forward(xs)
+        self.layer.forward(xs)
     }
 }
 
 impl Conv1dLayerLike for Conv1dWithWB {
     fn config(&self) -> &Conv1dConfig {
-        self.this.config()
+        self.layer.config()
     }
     fn weight(&self) -> &Tensor {
         &self.weights
@@ -123,20 +123,20 @@ pub trait Conv2dLayerLike: Module {
 
 #[derive(Debug)]
 pub struct Conv2dWithWB {
-    pub this: Conv2d,
+    pub layer: Conv2d,
     pub weights: Tensor,
     pub bias: Option<Tensor>,
 }
 
 impl Module for Conv2dWithWB {
     fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        self.this.forward(xs)
+        self.layer.forward(xs)
     }
 }
 
 impl Conv2dLayerLike for Conv2dWithWB {
     fn config(&self) -> &Conv2dConfig {
-        self.this.config()
+        self.layer.config()
     }
     fn weight(&self) -> &Tensor {
         &self.weights
@@ -145,3 +145,29 @@ impl Conv2dLayerLike for Conv2dWithWB {
         self.bias.as_ref()
     }
 }
+
+pub trait EmbeddingLayerLike: Module {
+    fn embeddings(&self) -> &Tensor;
+    fn hidden_size(&self) -> usize;
+}
+
+#[derive(Debug)]
+pub struct EmbeddingWithSize {
+    pub layer: Embedding,
+    pub hidden_size: usize,
+}
+
+impl Module for EmbeddingWithSize {
+    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
+        self.layer.forward(xs)
+    }
+}
+
+impl EmbeddingLayerLike for EmbeddingWithSize {
+    fn embeddings(&self) -> &Tensor {
+        self.layer.embeddings()
+    }
+    fn hidden_size(&self) -> usize {
+        self.hidden_size
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 62c218e..77c0b9a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -65,7 +65,7 @@ fn main() -> Result<()> {
     )?;
 
     let conv = Conv2dWithWB {
-        this: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
+        layer: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
         weights: conv_weight,
         bias: Some(conv_bias),
     };
diff --git a/tests/conv1d.rs b/tests/conv1d.rs
index dcde732..5fcbc26 100644
--- a/tests/conv1d.rs
+++ b/tests/conv1d.rs
@@ -56,7 +56,7 @@ fn conv1d() -> candle_core::Result<()> {
     )?;
 
     let conv = Conv1dWithWB {
-        this: Conv1d::new(
+        layer: Conv1d::new(
             conv_weight.clone(),
             Some(conv_bias.clone()),
             Conv1dConfig::default(),
diff --git a/tests/conv2d.rs b/tests/conv2d.rs
index f773dc8..38ace9a 100644
--- a/tests/conv2d.rs
+++ b/tests/conv2d.rs
@@ -66,7 +66,7 @@ fn main() -> candle_core::Result<()> {
     )?;
 
     let conv = Conv2dWithWB {
-        this: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
+        layer: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
         weights: conv_weight,
         bias: Some(conv_bias),
     };
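
Usage note (not part of the patch): a minimal sketch of how the EmbeddingWithSize wrapper introduced above might be driven. The crate path candle_lora, the vocabulary size, and the token values are assumptions for illustration only and do not come from this diff.

use candle_core::{Device, Result, Tensor};
use candle_lora::EmbeddingWithSize; // hypothetical crate path for this sketch
use candle_nn::{Embedding, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let hidden_size = 4;

    // A toy vocabulary of 10 tokens, each mapped to a hidden_size-dim vector.
    let weights = Tensor::randn(0f32, 1f32, (10, hidden_size), &device)?;
    let embed = EmbeddingWithSize {
        layer: Embedding::new(weights, hidden_size),
        hidden_size,
    };

    // forward keeps the index shape and appends the hidden dimension,
    // mirroring FrozenEmbedding's flatten/index_select/reshape path above.
    let tokens = Tensor::new(&[[0u32, 2], [5, 9]], &device)?;
    let embedded = embed.forward(&tokens)?;
    assert_eq!(embedded.dims(), &[2, 2, hidden_size]);
    Ok(())
}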