
Generate embeddings from CLAYModule trained with latlon/time encodings #96

Merged
7 commits merged into main on Jan 12, 2024

Conversation

@weiji14 (Contributor) commented Dec 20, 2023

What I am changing

How I did it

  • In the LightningModule's predict_step, implement the logic to run the forward pass and save the embeddings to GeoParquet (gpq)

  • Raw embeddings have shape (1, 1538, 768); we take the mean over the 1536 patch embeddings (shape (1, 1536, 768)), which yields a (1, 768) embedding

  • Sample output table would look like this (same as Rename embeddings file to include MGRS code and store GeoTIFF source_url #86):

    | source_url                  | date       | embeddings           | geometry     |
    |-----------------------------|------------|----------------------|--------------|
    | s3://.../.../claytile_*.tif | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) |
    | s3://.../.../claytile_*.tif | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
    | s3://.../.../claytile_*.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
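The mean-pooling step described above can be sketched as follows (a minimal sketch: the assumption that the 2 non-patch tokens sit at the end of the 1538-token sequence is mine — the actual slicing lives in model_clay.py):

```python
import torch

# Raw embedding from the encoder: 1536 patch tokens plus 2 extra tokens.
# Which positions hold the extra tokens is an assumption here.
raw_embedding = torch.randn(1, 1538, 768)

# Keep only the 1536 patch embeddings, then average over the patch axis
patch_embeddings = raw_embedding[:, :1536, :]  # shape (1, 1536, 768)
mean_embedding = patch_embeddings.mean(dim=1)  # shape (1, 768)
```

Averaging over the patch axis collapses the per-patch detail into a single fixed-length vector per image, which is what gets stored in the embeddings column.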

TODO in this PR:

  • Implement predict_step to generate gpd.GeoDataFrame table
  • Implement on_predict_epoch_end to merge gpd.GeoDataFrame tables and output to GeoParquet file(s)
  • Add a unit test
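The merge step in on_predict_epoch_end can be sketched roughly like this (using plain pandas DataFrames as stand-ins for gpd.GeoDataFrame, which subclasses pd.DataFrame; the table contents are made up):

```python
import pandas as pd

# Each predict_step is assumed to return one small table per batch
batch_tables = [
    pd.DataFrame(
        data={
            "source_url": [f"s3://bucket/claytile_{i}.tif"],
            "date": ["2021-01-01"],
            "embeddings": [[0.1, 0.2, 0.3]],
        }
    )
    for i in range(3)
]

# on_predict_epoch_end then merges the per-batch tables into one,
# before writing it out, e.g. with merged.to_parquet(...)
merged = pd.concat(objs=batch_tables, ignore_index=True)
```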

TODO in the future:

  • Refactor to reduce duplicated code?
  • Use shared callback between model_vit.py and model_clay.py?
  • Upload GeoParquet embedding files to HuggingFace datasets

How you can test it

  1. Ensure you have access to the 13-band GeoTIFF data files on s3://clay-tiles-02/02/
  2. Download the pretrained model from s3://clay-model-ckpt/v0/mae_epoch-02_val-loss-0.52.ckpt to the checkpoints/ folder.
  3. Run the following in a bash shell:
python trainer.py predict --ckpt_path=checkpoints/mae_epoch-02_val-loss-0.52.ckpt \
                          --trainer.precision=bf16-mixed \
                          --data.data_dir=s3://clay-tiles-02/02/48MYU \
                          --data.batch_size=32 \
                          --data.num_workers=16
  • This should produce a 48MYU_20180813_20210424_v001.gpq file under the data/embeddings/ folder. Sample file (need to unzip): 48MYU_20180813_20210424_v001.gpq.zip

  • Extra configuration options can be found using python trainer.py predict --help

To load the embeddings from the GeoParquet file:

import geopandas as gpd

geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path="48MYU_20180813_20210424_v001.gpq")
assert geodataframe.shape == (1217, 5)  # 1217 rows, 5 columns
print(geodataframe)
      index                                         source_url        date  \
0         0  s3://clay-tiles-02/02/48MYU/2018-08-13/claytil...  2018-08-13   
1         1  s3://clay-tiles-02/02/48MYU/2018-08-13/claytil...  2018-08-13   
2         2  s3://clay-tiles-02/02/48MYU/2018-08-13/claytil...  2018-08-13   
3         3  s3://clay-tiles-02/02/48MYU/2018-08-13/claytil...  2018-08-13   
4         4  s3://clay-tiles-02/02/48MYU/2018-08-13/claytil...  2018-08-13   
...     ...                                                ...         ...   
1212   1212  s3://clay-tiles-02/02/48MYU/2021-04-24/claytil...  2021-04-24   
1213   1213  s3://clay-tiles-02/02/48MYU/2021-04-24/claytil...  2021-04-24   
1214   1214  s3://clay-tiles-02/02/48MYU/2021-04-24/claytil...  2021-04-24   
1215   1215  s3://clay-tiles-02/02/48MYU/2021-04-24/claytil...  2021-04-24   
1216   1216  s3://clay-tiles-02/02/48MYU/2021-04-24/claytil...  2021-04-24   

                                             embeddings  \
0     [0.013126503, -0.031934112, 0.0054517575, 0.00...   
1     [0.01362492, -0.03131817, 0.005478967, 0.00358...   
2     [0.013637519, -0.03169147, 0.0055654137, 0.003...   
3     [0.013152077, -0.027163014, 0.007045647, 0.000...   
4     [0.007802248, -0.018802581, 0.0039559323, -0.0...   
...                                                 ...   
1212  [-0.0010275859, -0.005840208, -0.0011097308, -...   
1213  [-0.000579659, -0.004794828, -0.001176401, -0....   
1214  [0.00043649378, -0.004590468, -0.0011525226, -...   
1215  [-0.0012016017, -0.002848133, -0.0016901258, -...   
1216  [-0.00092684187, -0.0075725354, -0.0019668005,...   

                                               geometry  
0     POLYGON ((106.85102 -5.47164, 106.85088 -5.425...  
1     POLYGON ((106.89721 -5.47150, 106.89707 -5.425...  
2     POLYGON ((106.94341 -5.47135, 106.94326 -5.425...  
3     POLYGON ((106.98960 -5.47120, 106.98945 -5.424...  
4     POLYGON ((107.03580 -5.47104, 107.03564 -5.424...  
...                                                 ...  
1212  POLYGON ((107.59432 -6.39429, 107.59409 -6.348...  
1213  POLYGON ((107.64057 -6.39406, 107.64033 -6.347...  
1214  POLYGON ((107.68682 -6.39382, 107.68658 -6.347...  
1215  POLYGON ((107.73306 -6.39357, 107.73282 -6.347...  
1216  POLYGON ((107.77931 -6.39333, 107.77906 -6.347...  

[1217 rows x 5 columns]
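Since the embeddings column holds one list/array per row, a common follow-up is to stack it into a 2D matrix for similarity search or clustering (a toy sketch, with made-up 4-dim embeddings standing in for the real 768-dim ones):

```python
import numpy as np
import pandas as pd

# Toy frame mimicking the embeddings column of the loaded GeoDataFrame
frame = pd.DataFrame(data={"embeddings": [np.random.rand(4) for _ in range(5)]})

# Stack the per-row arrays into a single (n_rows, n_dims) matrix
embedding_matrix = np.stack(frame["embeddings"].tolist())
```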

Related Issues

Towards #3

Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copying and adapting the code from a767164 in #73, but modifying how the encoder's masking is disabled, and how the mean/average of the embeddings is computed over a slice of the raw embeddings.
@weiji14 weiji14 self-assigned this Dec 20, 2023
The output GeoParquet file now has a filename with a format like "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq", e.g. "12ABC_20210101_20231231_v001.gpq". Have implemented this in model_vit.py, and copied over the same `on_predict_epoch_end` method to model_clay.py. Also, we are no longer saving out the index column to the GeoParquet file.
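The filename scheme described above can be reconstructed roughly as follows (a sketch with made-up variable names; the actual implementation lives in model_vit.py and model_clay.py):

```python
# Hypothetical reconstruction of "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq"
mgrs_code = "12ABC"
dates = ["2023-12-31", "2021-01-01", "2022-06-15"]

mindate = min(dates).replace("-", "")  # earliest date as YYYYMMDD
maxdate = max(dates).replace("-", "")  # latest date as YYYYMMDD
version = 1

filename = f"{mgrs_code}_{mindate}_{maxdate}_v{version:03d}.gpq"
# → "12ABC_20210101_20231231_v001.gpq"
```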
Forgot to update the filename in the unit test to conform to the new `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` format. Patches f19cf8f.
Splitting the previous integration test on the neural network model into separate fit and predict unit tests. Only testing the prediction loop of CLAYModule, because training/validating the model might be too heavy for CPU-based Continuous Integration. Also, for testing CLAYModule we are using 32-true precision instead of bf16-mixed, because `torch.cat` doesn't work with float16 tensors on the CPU, see pytorch/pytorch#100932 (should be fixed in PyTorch 2.2).
@pytest.mark.parametrize(
    "litmodule,precision",
    [
        (CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
@weiji14 (Contributor, Author) commented:

There are some torch.cat calls in CLAYModule that don't work when run on CPU with float16 tensors, see pytorch/pytorch#100932. The patch at pytorch/pytorch#96093 fixing this issue has already been merged though, so we can remove this if-then statement once PyTorch 2.2 is out. Note that running CLAYModule on CUDA-enabled GPUs should be fine with either float16 or bfloat16.
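The precision selection referred to here boils down to something like the following (a sketch of the if-then statement, with a made-up helper name, not the exact test code):

```python
import torch

def choose_precision() -> str:
    # bf16-mixed on CUDA GPUs; full 32-true precision on CPU, where
    # torch.cat with float16 tensors can raise (pytorch/pytorch#100932)
    return "bf16-mixed" if torch.cuda.is_available() else "32-true"
```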

Decided that the index column might be good to keep for now, since it might help to speed up row counts? But we are resetting the index first before saving it. Partially reverts f19cf8f.
After f1439e3, need to ensure that the index column is checked in the output geodataframe.
@weiji14 weiji14 marked this pull request as ready for review January 11, 2024 22:36
@weiji14 weiji14 added this to the v0 Release milestone Jan 12, 2024
@weiji14 (Contributor, Author) commented Jan 12, 2024

There are still many things that could be improved, such as sharing the duplicated code between model_vit.py and model_clay.py, but will merge into main first for the first release.

@weiji14 weiji14 merged commit 7082c54 into main Jan 12, 2024
2 checks passed
@weiji14 weiji14 deleted the embeddings-there-and-back-again branch January 12, 2024 04:33