-
Notifications
You must be signed in to change notification settings - Fork 473
Add SSL4Eco SeCo-Eco Weights #2849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds a new weight configuration for ResNet50 using the SSL4Eco SeCo‐Eco method.
- Introduces a new weight definition (SENTINEL2_ALL_SECO_ECO) in the ResNet50_Weights class.
- Updates associated metadata (e.g., dataset, input channels, publication link, bands) for proper integration with timm and torchvision.
Comments suppressed due to low confidence (2)
torchgeo/models/resnet.py:644
- Verify that the provided publication link is accurate and corresponds to the SeCo-Eco method, and update the documentation if necessary.
'publication': 'https://arxiv.org/abs/2504.18256',
torchgeo/models/resnet.py:647
- Consider adding a comment explaining why 'NDVI' is included in the bands list along with raw satellite bands, as it is typically a derived index.
'bands': ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'NDVI'],
7276028
to
5113d00
Compare
Thank you @isaaccorley @adamjstewart @PlekhanovaElena for the integration 🙏🙏🙏 I have a small concern regarding the transforms. We have a conditional behavior depending on the size of the input image: class SetPatchSizeToPretraingSize(torch.nn.Module):
def __init__(self, patch_size, interpolation='BICUBIC'):
super().__init__()
self.patch_size = patch_size
self.interpolation = interpolation
def forward(self, img):
h, w, _ = img.shape
if h < self.patch_size or w < self.patch_size:
transform = cvtransforms.Resize(
self.patch_size,
interpolation=self.interpolation)
else:
transform = cvtransforms.CenterCrop(self.patch_size)
return transform(img) The proposed kornia-only implementation in TorchGeo would indeed reproduce our behavior for images smaller than _seco_eco_transforms = K.AugmentationSequential(
K.Resize((224, 224)),
... I can create a new PR to fix this but would need your opinion first @isaaccorley @adamjstewart. To the best of my knowledge, kornia does not allow conditional mechanisms. How essential is it that
|
We could start with CenterCrop without padding to get it to <= 224, then use Resize to get it to == 224. Then we wouldn't need any conditional logic. |
I though so too, but CenterCrop does not seem to like input images smaller than the required output crop size: import torch
import kornia.augmentation as K
K.CenterCrop((20, 20))(torch.rand(3, 10, 10))
>>> Traceback (most recent call last):
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-24-0987adfc757d>", line 1, in <module>
K.CenterCrop((20, 20), keepdim=False)(torch.rand(3, 10, 10))
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/kornia/core/module.py", line 311, in __call__
_output_image = decorated_forward(*inputs, **kwargs)
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/kornia/core/module.py", line 81, in wrapper
tensor_outputs = func(*args, **kwargs)
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/kornia/augmentation/base.py", line 251, in forward
params = self.forward_parameters(batch_shape)
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/kornia/augmentation/base.py", line 218, in forward_parameters
_params = self.generate_parameters(torch.Size((int(to_apply.sum().item()), *batch_shape[1:])))
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/kornia/augmentation/_2d/geometric/center_crop.py", line 109, in generate_parameters
return rg.center_crop_generator(batch_shape[0], batch_shape[-2], batch_shape[-1], self.size, self.device)
File "/home/damien/miniconda3/envs/ssl4eco/lib/python3.10/site-packages/kornia/augmentation/random_generator/_2d/crop.py", line 302, in center_crop_generator
raise AssertionError(f"Crop size must be smaller than input size. Got ({height}, {width}) and {size}.")
AssertionError: Crop size must be smaller than input size. Got (10, 10) and (20, 20). Would you maybe have a workaround in mind ? |
You would need to use K.PadTo((224,224)) |
Which, to the best of my knowledge, would in turn not behave as required for images larger than 224: K.PadTo((5, 5))(torch.arange(100).view(10, 10).float()) |
@drprojects Was the pretraining size set to 224x224? Or are these conditional transforms only used for downstream fine-tuning? Doing a search for If we are dead set on using this transform though, I think it would be preferable to upstream this to Kornia directly as it's not geospatial specific and seems like torchvision doesn't have something similar to this either so it would be beneficial for the non-geospatial community as well. |
The pretraining was done on 224x224 images. More precisely, it was performed on random crops resized to 224x224. Hence, the model should in theory be robust to a certain level of stretching (which aligns with using If you think this is fine and that TorchGeo users will adjust their transforms for downstream tasks, then I am also OK with leaving things as is! |
Adds the RN50 weights pretrained on the new SSL4Eco dataset using the SeCo-Eco method modified from the original SeCo method. This is from the CVPR Earthvision paper "SSL4Eco: A Global Seasonal Dataset for Geospatial Foundation Models in Ecology", Plekhanova et al. (2025).
Weights were extracted from the checkpoint and tweaked to be loadable with timm and torchvision resnet50 and rehosted to huggingface here.
cc: @PlekhanovaElena @drprojects @jdollinger-bit