Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jun 16, 2024
1 parent 2ba5058 commit d586d98
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions mistralrs-core/src/vision_models/idefics2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,14 @@ impl VisionEmbeddings {
let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

let patch_embeds = self.patch_embedding.forward(pixel_values)?;

println!("saving `patch_embeds`");
patch_embeds
.to_dtype(DType::F32)?
.to_device(&Device::Cpu)?
.write_npy("pixel_values_probe_m.npy")?;
println!("saved it");

let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

let (max_nb_patches_h, max_nb_patches_w) =
Expand Down Expand Up @@ -895,6 +903,9 @@ impl Idefics2 {
) -> Result<Tensor> {
let input_embeds = if let Some(pixel_values) = pixel_values {
// == START VISUAL INPUTS INTEGRATION ==
let pixel_values = Tensor::read_npy("../pixel_values_start.npy")?
.to_dtype(pixel_values.dtype())?
.to_device(pixel_values.device())?;
let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
let mut s = vec![batch_size * num_images];
s.extend(pixel_values.dims()[2..].to_vec());
Expand Down Expand Up @@ -922,10 +933,6 @@ impl Idefics2 {
}
}
let pixel_values = Tensor::cat(&batches, 0)?;
pixel_values
.to_dtype(DType::F32)?
.to_device(&Device::Cpu)?
.write_npy("pixel_values_selected_m.npy")?;

// Vision attention mask
let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
Expand All @@ -934,7 +941,18 @@ impl Idefics2 {
pixel_attention_mask.dims()[2],
pixel_attention_mask.dims()[3],
))?;
pixel_attention_mask.index_select(&real_images_inds, 0)?
let mut batches = Vec::new();
for (batch, use_it) in pixel_attention_mask
.chunk(pixel_attention_mask.dim(0)?, 0)?
.iter()
.zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
{
let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
if use_it {
batches.push(batch.clone());
}
}
Tensor::cat(&batches, 0)?
} else {
Tensor::ones(
(
Expand Down

0 comments on commit d586d98

Please sign in to comment.