Add option to output raw patch embeddings #133

Merged: 14 commits from store-raw-embeddings into main on Jan 29, 2024
Conversation

@yellowcap (Member) commented Jan 23, 2024

This PR adds an option to the Lightning CLI to output patch-level embeddings, i.e. one embedding per patch in each chip. The band-group dimension is reduced, so the patch embeddings are averages over the band groups.

This adds two args to the CLI: `output_patch_embeddings` and `shuffle`. For the patch embeddings we need to ensure that shuffle is off, while it is on by default. See also #123.

Updated the documentation to explain these changes.

This PR closes #130.

if self.hparams.output_patch_embeddings:
    embeddings_raw = rearrange(
        embeddings_raw[:, :-2, :], "b (w h g) s -> b (w h s) g", w=16, h=16, g=6
    )
Collaborator:

The order of the tensor matters; in this case the output from the encoder, i.e. `embeddings_raw`, is of shape batch x (group x num_spatial_patches) x embedding_dims.

So the einops operation should be `b (g l) d -> b g l d`. You can check this notebook for reference: https://github.com/Clay-foundation/model/blob/docs/model/docs/clay-v0-visualization.ipynb
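A minimal sketch of the suggested decomposition, using plain numpy in place of einops (the batch size of 2 is hypothetical; 6 band groups, 256 spatial patches, and 768 embedding dims follow the shapes discussed here):

```python
import numpy as np

# Stand-in for the encoder output described above:
# batch x (group * num_spatial_patches) x embedding_dims.
B, g, l, d = 2, 6, 256, 768
embeddings_raw = np.arange(B * g * l * d).reshape(B, g * l, d)

# einops "b (g l) d -> b g l d" decomposes the packed axis with the
# group index varying slowest, which is exactly a C-order reshape.
unpacked = embeddings_raw.reshape(B, g, l, d)
print(unpacked.shape)  # (2, 6, 256, 768)
```

The key point is that the decomposition order in the einops pattern must match how the encoder packed the axis; `b (g l) d` and `b (l g) d` produce different, silently wrong results on the same tensor.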

Member Author:

Done, but please review again to make sure I got it right this time. Related to this, do you think this is a good way to unravel the patch embeddings?

Contributor:

Mentioned this below at #133 (comment), but can we just keep the unravelled shape (B, 256, 768), or (B, 1536, 768)? Less work for the downstream user since they won't need to figure out how to unravel the tensor.

Member Author:

Yes absolutely. Don't know why I thought it had to be a 1d array...

Contributor:

That said, do we want to store a (256, 768) embedding tensor in a single row, or split it into 256 rows of 768-length embeddings? I'm reading the original thread at #127, and it sounds like we need to enable searching on the 32x32 patch level instead of 256x256 chips, which might mean storing a row for each patch?

We can store 2D arrays in GeoParquet, but I'm not sure if vector databases allow indexing 2D arrays, or if we need to make it 1D. Note that this would increase the embedding file size significantly (x256), and the vector database indexing will be much slower. But if these raw embeddings are only meant to be generated on an ad-hoc basis for small locations, and not for similarity search applications, it should be fine.

Member Author:

I thought about that too, but that would also imply that we have to track the MGRS tile and the source path 256x more. That seems inefficient. My thinking was that the splitting into rows can happen at the moment of ingestion into a vector search DB, or when analysing if required. For "local manual inspection" situations, having the patches in a single array is useful too. So I vote to keep it at one row in the GeoDataFrame and let the separation be a downstream issue.
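To make the trade-off concrete, here is a hedged sketch of the two layouts being discussed: one row per chip versus one row per patch with metadata repeated. The column names and the tiny stand-in array (instead of the real (256, 768) embeddings) are hypothetical:

```python
import numpy as np
import pandas as pd

# Hypothetical chip with a small (4, 8) patch-embedding array,
# standing in for the (256, 768) case from the discussion.
emb = np.arange(32, dtype=float).reshape(4, 8)
chip = {"mgrs_tile": "32UQD", "source": "s3://bucket/chip.tif"}

# Layout 1: one row per chip, the whole array in a single cell.
df_chip = pd.DataFrame([{**chip, "embeddings": emb}])

# Layout 2: one row per patch; the metadata columns are duplicated
# for every patch, which is the 256x overhead mentioned above.
df_patch = pd.DataFrame(
    [{**chip, "patch_id": i, "embedding": emb[i]} for i in range(len(emb))]
)
print(len(df_chip), len(df_patch))  # 1 4
```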

src/model_clay.py (outdated; resolved)
src/model_clay.py (outdated; resolved)
The `output_patch_embeddings` flag determines how the embeddings are calculated.
If `False`, one average embedding per MGRS tile of size 768 will be created. If
`True`, the embeddings will be kept at the patch level. The embedding array will
be of size 16 * 16 * 768, representing.
Contributor:

Dangling sentence, representing what? Also, should be size 16*16, 768 no?

Member Author:

Updated wording to adapt to the embedding levels input.

docs/model_embeddings.md (outdated; resolved)
output embedding of shape (B, 768).
2. By default, the mean or average is taken across the 1536 patch dimension,
yielding an output embedding of shape (B, 768). If patch embeddings are
requested, the shape is (B, 16 * 16 * 768), one embedding per patch.
Contributor:

If we decide to use the unsqueezed shape:

Suggested change:
- requested, the shape is (B, 16 * 16 * 768), one embedding per patch.
+ requested, the shape is (B, 6*16*16, 768), one embedding per patch.

Comment on lines 898 to 911

if self.hparams.output_patch_embeddings:
    # Take the mean of the embeddings along the group dimension
    # excluding the last two latlon_ and time_ embeddings. This
    # results in one embedding per patch.
    embeddings_raw = rearrange(
        embeddings_raw[:, :-2, :], "b (g l) s -> b g (l s)", l=256, g=6
    )
    embeddings_mean = reduce(embeddings_raw, "b g s -> b s", "mean")
    assert embeddings_mean.shape == torch.Size(
        [
            self.model.encoder.B,
            256 * 768,
        ]  # (batch_size, nr of patches * hidden_size)
    )
Contributor:

Why are we squeezing the embeddings to a size like (B, 256*768)? Can't we just keep the raw embedding shape of (B, 1536, 768), where 1536 is 6 (band groups) * 16 * 16 (spatial patches)? It will be a lot harder to unsqueeze a 196608-length tensor back to 6*16*16 later.

Member Author:

Unsqueezed would undoubtedly be better, but I assumed that we can only store one-dimensional arrays in a field. I guess from your question that is not the case. Can we store multidimensional arrays here? If yes, let's store the unsqueezed one for sure!

Contributor:

Yes, GeoParquet/Arrow technically supports FixedShapeTensorArray which can be multi-dimensional.

@yellowcap (Member Author) commented Jan 25, 2024:

I just found out that this will require a rewrite without pandas, because pandas cannot handle this and implicitly flattens the FixedShapeTensorArray, see example below. I ran across this when running tests and not getting the expected multidimensional arrays. This definitely complicates things in terms of keeping the structure.

import numpy as np
import pandas as pd
import pyarrow as pa

array = np.arange(8).reshape((2, 2, 2))
arrow = pa.FixedShapeTensorArray.from_numpy_ndarray(array)
df = pd.DataFrame(data={"arrow": arrow})
print(df)
#           arrow
# 0  [0, 1, 2, 3]
# 1  [4, 5, 6, 7]

Member Author:

OK, so based on the above, the arrays are output in a flat structure for now. I updated the documentation with a sentence about where the structure comes from and why the arrays are flat.

https://github.com/Clay-foundation/model/pull/133/files#diff-2f39d0012d2540cdab1ac1ff882a3875d829b85f0acee68fe102e6b8a4f2c6c0R60

This is not ideal but keeping the multidimensional arrays now would require refactoring of how we construct the geoparquet files.
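For downstream users, recovering the patch structure from the flat arrays is a single reshape, provided the packing order is known. A minimal numpy sketch (the zeros array is a stand-in for a real flattened embedding read back from the GeoParquet file):

```python
import numpy as np

# Hypothetical flattened patch embedding as stored in one row:
# 256 patches * 768 embedding dims, with the patch index varying
# slowest and the embedding dimension fastest.
flat = np.zeros(256 * 768)

# A C-order reshape recovers the per-patch layout.
patches = flat.reshape(256, 768)
print(patches.shape)  # (256, 768)
```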

Contributor:

Hmm, have you tried inserting a FixedShapeTensorArray into a row cell directly? This seems to work:

import numpy as np
import pandas as pd
import pyarrow as pa

array0 = np.arange(0, 8).reshape((2, 2, 2))
array1 = np.arange(9, 17).reshape((2, 2, 2))

arrow0 = pa.FixedShapeTensorArray.from_numpy_ndarray(array0)
arrow1 = pa.FixedShapeTensorArray.from_numpy_ndarray(array1)

df = pd.DataFrame(data={"embeddings": [arrow0, arrow1]})
print(df)
#                             embeddings
# 0         ([0, 1, 2, 3], [4, 5, 6, 7])
# 1  ([9, 10, 11, 12], [13, 14, 15, 16])

Member Author:

The arrow part still flattens the arrays that are input. For this to work, we would have to do the manual deconstruction up to 3 dimensions deep. Not sure that this will make things more user-friendly. With this approach, the user would have to reconstruct a list of a list of a list of arrays in the "group" case.

docs/model_embeddings.md (outdated; resolved)
No longer necessary after #135
@srmsoumya (Collaborator) left a comment:

Like the idea of storing embeddings at multiple levels.

@@ -87,19 +87,27 @@ def test_model_vit_fit(datapipe):
 @pytest.mark.parametrize(
     "litmodule,precision",
     [
-        (CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
         (ViTLitModule, "bf16-mixed"),
+        (CLAYModule, "16-mixed" if torch.cuda.is_available() else "32-true"),
Collaborator:

Is bf16-mixed precision giving any issues during inference?

Member Author:

It is for the Nvidia GeForce RTX on my local machine, which apparently does not support this type.
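A hedged sketch of how such a precision fallback could be selected at runtime. This is an illustration, not the code used in the PR; note that `torch.cuda.is_bf16_supported()` is only queried after confirming CUDA is available:

```python
import torch

# Pick a Lightning precision string based on hardware support:
# prefer bf16 where the GPU supports it, fall back to fp16 on
# older CUDA GPUs, and use full fp32 on CPU-only machines.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    precision = "bf16-mixed"
elif torch.cuda.is_available():
    precision = "16-mixed"
else:
    precision = "32-true"
print(precision)
```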

@yellowcap yellowcap merged commit 7a48658 into main Jan 29, 2024
4 checks passed
@yellowcap yellowcap deleted the store-raw-embeddings branch January 29, 2024 21:30
Successfully merging this pull request may close the issue: Update train.py script to optionally store the raw encoder output

3 participants