
Generate embeddings via prediction loop #56

Merged
merged 6 commits into main from gen-embeddings on Dec 4, 2023

Conversation

Contributor

@weiji14 weiji14 commented Nov 29, 2023

What I am changing

  • Generate embeddings from a pretrained ViT encoder, stored in an npy file format.

How I did it

Illustration of how the raw embeddings are turned into the final embeddings of shape (1, 768).

Excalidraw link: https://excalidraw.com/#json=IDteKVYDAHd05wT-rCR6K,vmbiRIb5ucGiXP6R1idu4w
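In code terms, the diagram amounts to dropping the cls_token and averaging the 64 patch embeddings. A minimal numpy sketch of that reduction (assuming a raw encoder output of shape (1, 65, 768) with the cls_token at index 0; the random array below is just a stand-in):

import numpy as np

# Stand-in for the raw encoder output: (batch, 1 cls_token + 64 patches, hidden_size)
raw_embeddings = np.random.rand(1, 65, 768).astype(np.float32)

# Drop the cls_token (index 0) and take the mean over the 64 patch embeddings
final_embedding = raw_embeddings[:, 1:, :].mean(axis=1)
assert final_embedding.shape == (1, 768)

np.save(file="embedding_0.npy", arr=final_embedding)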

TODO in this PR:

  • Implement predict_dataloader and predict_step
  • Add a unit test
  • Refactor predict_dataloader to not rely on the validation datapipe
  • Compute mean of the raw patch embeddings

TODO in the future:

  • Properly document how to generate/save embeddings and load them (see instructions below)
  • Put some sort of spatiotemporal metadata in the filename of the output embedding
  • Refactor to save only encoder weights to checkpoint instead of entire encoder/decoder network

How you can test it

  • Locally, download some GeoTIFF data into the data/ folder, and then run:
python trainer.py fit --trainer.max_epochs=10 --trainer.precision=bf16-mixed --data.data_path=data/56HKH --data.num_workers=4  # train the model
python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1 --trainer.limit_predict_batches=1  # generate embeddings
  • This should produce an embedding_0.npy file under the data/embeddings/ folder. Sample files are attached (need to unzip):

  • Extra configuration options can be found using python trainer.py predict --help

To load the embeddings from the npy file:

import numpy as np

array: np.ndarray = np.load(file="embedding_0.npy")
assert array.shape == (1, 768)

Related Issues

Towards #3

Implement the embedding generator in the LightningModule's predict_step. The embeddings are tensor arrays that are saved to a .npy file in the data/embeddings/ folder. Input data is retrieved from the predict_dataloader, which is currently using the validation datapipe rather than a dedicated datapipe. Have documented how to generate the embedding output file using LightningCLI on the main README.md file. Also added a unit test to ensure that saving and loading from an embedding_0.npy file works.
@weiji14 weiji14 added the model-architecture Pull requests about the neural network model architecture label Nov 29, 2023
@weiji14 weiji14 added this to the v0 Release milestone Nov 29, 2023
@weiji14 weiji14 self-assigned this Nov 29, 2023
src/model_vit.py Outdated
Comment on lines 142 to 146
# Get embeddings generated from encoder
embeddings: torch.Tensor = outputs_encoder.last_hidden_state
assert embeddings.shape == torch.Size(
[self.B, 17, 768] # (batch_size, sequence_length, hidden_size)
)
Contributor Author

@srmsoumya, if you have time, could you try to see why the outputs of the encoder have a shape like (32, 17, 768), or (batch_size, sequence_length, hidden_size)? Specifically, I'm not sure what the sequence_length (size: 17) dimension is about, and couldn't quite figure it out from reading https://huggingface.co/docs/transformers/model_doc/vit_mae. More just something to understand, since we'll be moving to your MAE implementation in #47 later.

Collaborator
@srmsoumya srmsoumya Nov 29, 2023

@weiji14 we are masking out 75% of the patches from the image. Given our image size is 256 x 256 and we use a patch size of 32, that will give us 256 / 32 => 8 x 8 => 64 patches. Masking out 75% of these patches leaves 0.25 x 64 => 16 patches. Adding 1 extra cls token gives us a total of 17 tokens to input into the transformer portion of the encoder. This results in a tensor of shape batch_size: 32 x unmasked patches + cls: 17 x patch embedding: 768, which is also the output we get from the encoder.

When creating the embeddings from the encoder of the model, we should switch off the masking strategy.
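As a quick sanity check of that arithmetic (plain Python, independent of the project code):

image_size = 256
patch_size = 32
mask_ratio = 0.75

num_patches = (image_size // patch_size) ** 2           # 8 x 8 = 64 patches
unmasked_patches = int((1 - mask_ratio) * num_patches)  # 25% kept = 16 patches
sequence_length = unmasked_patches + 1                  # + 1 cls token = 17

assert sequence_length == 17  # matches the (32, 17, 768) encoder output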

Contributor Author
@weiji14 weiji14 Nov 29, 2023

Wonderful explanation! Thanks @srmsoumya, I've disabled the masking in the predict_step now at commit f09d2e7, and the output is now (32, 65, 768) as expected. Here's a new sample embedding for one image: embedding_0.npy.zip [array shape: (1, 65, 768)]

Collaborator

Perfect, that looks just right!

@weiji14 For embeddings per image, the general practice is either to pick the 1st embedding vector, i.e. batch_size x unmasked_patches[:1] x embedding_dim, or to take the mean of the remaining vectors, i.e. (batch_size x unmasked_patches[1:] x embedding_dim).mean(dim=1). Either way gives us embeddings of size batch_size x embedding_dim, which is a single vector representing an image.
The first vector represents the cls token, which should capture the feature embedding of the whole image (a concept borrowed from BERT); the mean of the remaining vectors should also work (which is what I am trying to implement with vit-pytorch, as it keeps the code less complicated).
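For reference, a minimal PyTorch sketch of those two options (the random tensor is only a stand-in for an encoder output of shape (batch_size, 65, 768), with the cls token at index 0):

import torch

batch_size, seq_len, hidden = 32, 65, 768
embeddings = torch.randn(batch_size, seq_len, hidden)  # stand-in for the encoder output

# Option 1: take the cls token (first position in the sequence)
cls_embedding = embeddings[:, 0, :]                # shape (batch_size, 768)

# Option 2: mean of the patch embeddings, skipping the cls token
mean_embedding = embeddings[:, 1:, :].mean(dim=1)  # shape (batch_size, 768)

assert cls_embedding.shape == mean_embedding.shape == (batch_size, hidden)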

Contributor Author
@weiji14 weiji14 Nov 30, 2023

Hmm, do you have any papers describing when either the cls token or mean of the patch embeddings are used? Also @leothomas, do you have any insights on which method of collapsing the embeddings into a single 1d vector was used for the similarity search/vector database projects you've worked on? Just want some extra context before deciding on which way to go.

Member

Hey there!

Just to clarify, the patches are 32x32 (pixel) subsections of the overall image, correct? Why are we masking 75% of them? Is that due to cloud cover, or seeking patches only over land?

Regarding the cls token, it seems to be a classification token, which represents the entire image.

The theory behind averaging the embeddings is that there is a very high likelihood (but not a guarantee) that the averages of 2 similar collections will be similar and the averages of 2 different collections will be different. It seems that this has mostly been researched in the case of text embeddings, where the exact order of words may matter less than in the case of image patches. I suspect that averaging the embeddings for each patch may lose some of the physical relationships of the overall image, but it would be very interesting to compare and contrast the two.

If we're going to use something like pgvector, we can easily have both types of embeddings colocated in the database and build a partial index for each!

Contributor Author

Just to clarify, the patches are 32x32 (pixel) subsections of the overall image, correct? Why are we masking 75% of them? Is that due to cloud cover, or seeking patches only over land?

Yes, each patch is 32x32 pixels. We were masking 75% of the patches because that's how a Masked Autoencoder is trained (see previous PR at #37), but for inference/prediction, we shouldn't apply the mask.

Regarding the cls token, it seems to be a classification token, which represents the entire image.

The theory behind averaging the embeddings is that there is a very high likelihood (but not a guarantee) that the averages of 2 similar collections will be similar and the averages of 2 different collections will be different. It seems that this has mostly been researched in the case of text embeddings, where the exact order of words may matter less than in the case of image patches. I suspect that averaging the embeddings for each patch may lose some of the physical relationships of the overall image, but it would be very interesting to compare and contrast the two.

If we're going to use something like pgvector, we can easily have both types of embeddings colocated in the database and build a partial index for each!

That sounds good actually. We could save two files: a cls_token embedding (1x768), and a patch embedding (either 1x768 or 64x768)?

Contributor Author
@weiji14 weiji14 Dec 1, 2023

Just for fun, I took a look at the embeddings in their raw, un-averaged form.

import pandas as pd
import numpy as np

embedding: np.ndarray = np.load(file="data/embeddings/embedding_0.npy")
df: pd.DataFrame = pd.DataFrame(data=embedding.squeeze())
df.shape  # (65, 768)

# Get descriptive statistics on each of the 768 columns (skipping the cls_token row 0)
df[1:].describe()

# Heatmap of embeddings (first row is cls_token)
df.style.background_gradient(axis="columns", cmap="Greens")

Row 0 is the cls_token, and rows 1-64 are the individual patches. Columns are the 768 embedding dimensions.

Descriptive stats: (screenshot)

Heatmap: (screenshot)

Scrolling through the 768 columns, I don't think I saw much in terms of outliers; the values seem pretty consistent within a column (standard deviation is usually <0.01), so I think it should be ok to just use the mean.

Member

If my intuition is correct, it's not that surprising that most tiles are "semantically flat", since most tiles will be "one thing": all forest, or grass, or mountain... The test for whether we need richer embeddings would be to pick examples that are semantically rich, or even semantically polarized, like an image with both land and water, or city and forest.

Also, since we train with a MAE, I can see how each semantic patch will actually learn to include the expected semantics of the surrounding patches within the chip, not just its own content, so that the inference can recreate the missing bits, making the stdev within a patch smaller.

Am I making sense?

Contributor Author

Yep, it would be good to explore what the embeddings look like for tiles that have more diverse land cover types. I've done just that by picking the most diverse tile we have, details at #35 (comment)

Also, since we train with a MAE, I can see how each semantic patch will actually learn to include the expected semantics of the surrounding patches within the chip, not just its own content, so that the inference can recreate the missing bits, making the stdev within a patch smaller.

Am I making sense?

Yes, the 1x768 embedding generated from each 32x32 patch actually contains information from the other patches. More details at #67, and we can continue the discussion there!

Previously, 75% of the patches, or 48 out of a total of 64 were masked out, leaving 16 patches plus 1 cls_token = 17 sequences. Disabling the mask gives 64 + 1 cls_token = 65 sequences. Moved some assert statements with a fixed sequence_length dim from the forward function to the training_step. Also updated the unit test to ensure output embeddings have a shape like (batch_size, 65, 768).
Refactoring the setup method in the LightningDataModule to not do a random split on the predict stage. I.e. just do the GeoTIFF to torch.Tensor conversion directly, followed by batching and collating.
Need to explicitly pass an argument to stage in the test_geotiffdatapipemodule unit test. Testing both the fit and predict stages.
@weiji14 weiji14 marked this pull request as ready for review November 30, 2023 02:03
Make sure that the generated embeddings do not have NaN values in them.
Contributor Author

weiji14 commented Nov 30, 2023

Gonna leave this for a day or so for review. There are a few nice-to-haves (e.g. better documentation on how to read the embeddings, a better filename that includes some spatiotemporal metadata, etc.), but those can be done in follow-up PRs.

Instead of saving embeddings of shape (1, 65, 768), save out embeddings of shape (1, 768) instead. Done by taking the mean along the sequence_length dim, except for the cls_token part (first index in the 65).
Collaborator
@srmsoumya srmsoumya left a comment

Looks good.

Contributor Author

weiji14 commented Dec 4, 2023

Thanks for reviewing @srmsoumya, and everyone else for the comments. I'll merge this in now, noting that the code here generates (1, 768) shape averaged embeddings rather than the (1, 65, 768) shape raw embeddings. We can revise this later if we decide that we do want the raw embeddings.

@weiji14 weiji14 merged commit 97a9a36 into main Dec 4, 2023
2 checks passed
@weiji14 weiji14 deleted the gen-embeddings branch December 4, 2023 23:51
brunosan pushed a commit that referenced this pull request Dec 27, 2023
* 🍻 Generate embeddings via prediction loop

Implement the embedding generator in the LightningModule's predict_step. The embeddings are tensor arrays that are saved to a .npy file in the data/embeddings/ folder. Input data is retrieved from the predict_dataloader, which is currently using the validation datapipe rather than a dedicated datapipe. Have documented how to generate the embedding output file using LightningCLI on the main README.md file. Also added a unit test to ensure that saving and loading from an embedding_0.npy file works.

* 🐛 Disable masking of patches on predict_step

Previously, 75% of the patches, or 48 out of a total of 64 were masked out, leaving 16 patches plus 1 cls_token = 17 sequences. Disabling the mask gives 64 + 1 cls_token = 65 sequences. Moved some assert statements with a fixed sequence_length dim from the forward function to the training_step. Also updated the unit test to ensure output embeddings have a shape like (batch_size, 65, 768).

* ♻️ Refactor LightningDataModule to not do random split on predict

Refactoring the setup method in the LightningDataModule to not do a random split on the predict stage. I.e. just do the GeoTIFF to torch.Tensor conversion directly, followed by batching and collating.

* ✅ Test predict stage in geotiffdatamodule

Need to explicitly pass an argument to stage in the test_geotiffdatapipemodule unit test. Testing both the fit and predict stages.

* 👔 Ensure that embeddings have no NaN values

Make sure that the generated embeddings do not have NaN values in them.

* 🗃️ Take mean of the embeddings along sequence_length dim

Instead of saving embeddings of shape (1, 65, 768), save out embeddings of shape (1, 768) instead. Done by taking the mean along the sequence_length dim, except for the cls_token part (first index in the 65).