Add FrozenEmbedding
EricLBuehler committed Sep 10, 2023
1 parent b9f1b2f commit 343e200
Showing 5 changed files with 79 additions and 11 deletions.
42 changes: 42 additions & 0 deletions src/frozenembed.rs
@@ -0,0 +1,42 @@
use candle_core::{Result, Tensor};

use crate::EmbeddingLayerLike;

#[derive(Debug)]
pub(crate) struct FrozenEmbedding {
embeddings: Tensor,
hidden_size: usize,
}

impl FrozenEmbedding {
pub(crate) fn new(embeddings: &Tensor, hidden_size: usize) -> Result<Self> {

// CI annotations on line 12: "associated functions `new` and `new_from_linear`
// are never used" (a Clippy failure, plus Check and Test Suite warnings on
// ubuntu-latest, macOS-latest, and windows-latest, stable).
Ok(Self {
embeddings: embeddings.detach()?,
hidden_size,
})
}

pub(crate) fn new_from_linear(old: &dyn EmbeddingLayerLike) -> Result<Self> {
Self::new(old.embeddings(), old.hidden_size())
}
}

impl crate::Module for FrozenEmbedding {
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
// The output shape is the index shape with the hidden size appended.
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
// Flatten the indices, gather the matching rows of the frozen table,
// then restore the original index shape plus the hidden dimension.
let indexes = indexes.flatten_all()?;
let values = self.embeddings.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
}

impl EmbeddingLayerLike for FrozenEmbedding {
fn embeddings(&self) -> &Tensor {
&self.embeddings
}
fn hidden_size(&self) -> usize {
self.hidden_size
}
}
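
Since `FrozenEmbedding` and both constructors are `pub(crate)`, any usage example has to live inside the crate. A minimal sketch as a unit test, with made-up sizes (a 4-token vocabulary, hidden size 3):

    #[cfg(test)]
    mod tests {
        use candle_core::{Device, Result, Tensor};
        use candle_nn::Module;

        use crate::frozenembed::FrozenEmbedding;

        #[test]
        fn frozen_embedding_lookup() -> Result<()> {
            let dev = Device::Cpu;
            // A 4-token vocabulary with hidden size 3; values are random.
            let table = Tensor::randn(0f32, 1f32, (4, 3), &dev)?;
            let frozen = FrozenEmbedding::new(&table, 3)?;

            // A (2, 2) batch of token ids: the output appends the hidden dim.
            let ids = Tensor::new(&[[0u32, 1u32], [2u32, 3u32]], &dev)?;
            let out = frozen.forward(&ids)?;
            assert_eq!(out.dims(), &[2, 2, 3]);
            Ok(())
        }
    }

The `detach()` in the constructor keeps gradients from flowing back into the table, which is the point of the `Frozen*` family alongside `frozenconv` and `frozenlinear`.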
42 changes: 34 additions & 8 deletions src/lib.rs
@@ -1,15 +1,15 @@
// According to https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
#[doc = include_str!("../README.md")]
use candle_core::{Shape, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Linear, Module};
use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, Linear, Module};
use loraconv1d::{LoraConv1d, LoraConv1dConfig};
use loraconv2d::{LoraConv2d, LoraConv2dConfig};
use loralinear::{LoraLinear, LoraLinearConfig};
use std::{collections::HashMap, hash::Hash};

mod frozenconv;
mod frozenlinear;
mod frozenembed;
mod frozenlinear;
pub mod loraconv1d;
pub mod loraconv2d;
pub mod loralinear;
@@ -92,20 +92,20 @@ pub trait Conv1dLayerLike: Module {

#[derive(Debug)]
pub struct Conv1dWithWB {
pub this: Conv1d,
pub layer: Conv1d,
pub weights: Tensor,
pub bias: Option<Tensor>,
}

impl Module for Conv1dWithWB {
fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
self.this.forward(xs)
self.layer.forward(xs)
}
}

impl Conv1dLayerLike for Conv1dWithWB {
fn config(&self) -> &Conv1dConfig {
self.this.config()
self.layer.config()
}
fn weight(&self) -> &Tensor {
&self.weights
@@ -123,20 +123,20 @@ pub trait Conv2dLayerLike: Module {

#[derive(Debug)]
pub struct Conv2dWithWB {
pub this: Conv2d,
pub layer: Conv2d,
pub weights: Tensor,
pub bias: Option<Tensor>,
}

impl Module for Conv2dWithWB {
fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
self.this.forward(xs)
self.layer.forward(xs)
}
}

impl Conv2dLayerLike for Conv2dWithWB {
fn config(&self) -> &Conv2dConfig {
self.this.config()
self.layer.config()
}
fn weight(&self) -> &Tensor {
&self.weights
@@ -145,3 +145,29 @@ impl Conv2dLayerLike for Conv2dWithWB {
self.bias.as_ref()
}
}

pub trait EmbeddingLayerLike: Module {
fn embeddings(&self) -> &Tensor;
fn hidden_size(&self) -> usize;
}

#[derive(Debug)]
pub struct EmbeddingWithSize {
pub layer: Embedding,
pub hidden_size: usize,
}

impl Module for EmbeddingWithSize {
fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
self.layer.forward(xs)
}
}

impl EmbeddingLayerLike for EmbeddingWithSize {
fn embeddings(&self) -> &Tensor {
self.layer.embeddings()
}
fn hidden_size(&self) -> usize {
self.hidden_size
}
}
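
`EmbeddingWithSize` is public, so downstream code can wrap a `candle_nn::Embedding` and hand it to anything that accepts an `EmbeddingLayerLike` (such as `FrozenEmbedding::new_from_linear`). A hedged sketch, assuming the crate is imported as `candle_lora` and using illustrative sizes:

    use candle_core::{Device, Result, Tensor};
    use candle_lora::{EmbeddingLayerLike, EmbeddingWithSize}; // assumed crate name
    use candle_nn::{Embedding, Module};

    fn wrap_embedding() -> Result<()> {
        let dev = Device::Cpu;
        let hidden_size = 8;
        // A 16-token vocabulary. The wrapper records the hidden size, which
        // candle_nn::Embedding does not expose on its own.
        let table = Tensor::randn(0f32, 1f32, (16, hidden_size), &dev)?;
        let embed = EmbeddingWithSize {
            layer: Embedding::new(table, hidden_size),
            hidden_size,
        };
        // The trait surfaces exactly what FrozenEmbedding::new_from_linear reads.
        assert_eq!(embed.embeddings().dims(), &[16, hidden_size]);
        let out = embed.forward(&Tensor::new(&[1u32, 3, 5], &dev)?)?;
        assert_eq!(out.dims(), &[3, hidden_size]); // ids shape plus hidden dim
        Ok(())
    }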
2 changes: 1 addition & 1 deletion src/main.rs
@@ -65,7 +65,7 @@ fn main() -> Result<()> {
)?;

let conv = Conv2dWithWB {
this: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
layer: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
weights: conv_weight,
bias: Some(conv_bias),
};
2 changes: 1 addition & 1 deletion tests/conv1d.rs
@@ -56,7 +56,7 @@ fn conv1d() -> candle_core::Result<()> {
)?;

let conv = Conv1dWithWB {
this: Conv1d::new(
layer: Conv1d::new(
conv_weight.clone(),
Some(conv_bias.clone()),
Conv1dConfig::default(),
2 changes: 1 addition & 1 deletion tests/conv2d.rs
@@ -66,7 +66,7 @@ fn main() -> candle_core::Result<()> {
)?;

let conv = Conv2dWithWB {
this: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
layer: Conv2d::new(conv_weight.clone(), Some(conv_bias.clone()), cfg),
weights: conv_weight,
bias: Some(conv_bias),
};
