# AnySat Guide

#### Simple Usage

AnySat is available through PyTorch Hub.

In [4]:
import torch

model = torch.hub.load('gastruc/anysat', 'anysat', pretrained=True, force_reload=True, flash_attn=False)

Downloading: "https://github.com/gastruc/anysat/zipball/main" to /home/GAstruc/.cache/torch/hub/main.zip


#### Local usage

Repo installation:

```bash
git clone https://github.com/gastruc/AnySat.git
cd AnySat
pip install -e AnySat
```



In [1]:
from hubconf import AnySat

model = AnySat.from_pretrained('base', flash_attn=False)  # Set flash_attn=True if you have the flash-attn module installed (url flash attn)
# device = "cuda" if you want to run on GPU; the default is CPU
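
The `flash_attn` flag can also be chosen at runtime rather than hard-coded. The snippet below is a minimal sketch, assuming the flash-attn package is importable under the module name `flash_attn`; the helper `flash_attn_available` is an illustrative name and not part of AnySat.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if a module named `flash_attn` can be found in this environment."""
    # find_spec looks the module up without importing it.
    return importlib.util.find_spec("flash_attn") is not None


use_flash_attn = flash_attn_available()
print("flash_attn available:", use_flash_attn)

# The result can then be passed to the loader shown above, e.g.:
# model = AnySat.from_pretrained('base', flash_attn=use_flash_attn)
```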

#### Experiments Reproduction

All experiments are available in the [experiments](https://github.com/gastruc/AnySat/tree/main/experiments) folder.

To reproduce the AnySat environment, run:

```bash
# clone project
git clone https://github.com/gastruc/anysat
cd anysat

# [OPTIONAL] create conda environment
conda create -n anysat python=3.9
conda activate anysat

# install requirements
pip install -r requirements.txt

# Create data folder where you can put your datasets
mkdir data
# Create logs folder
mkdir logs
```

Then run the experiment you want:

```bash
# Run AnySat pretraining on GeoPlex
python src/train.py exp=GeoPlex_AnySAT

# Run AnySat finetuning on BraDD-S1TS
python src/train.py exp=BraDD_AnySAT_FT

# Run AnySat linear probing on BraDD-S1TS
python src/train.py exp=BraDD_AnySAT_LP
```

You can override any parameter through Hydra. For example, to train a Small version of AnySat on the GeoPlex datasets, run:

```bash
python src/train.py exp=GeoPlex_AnySAT model=Any_Small_multi
```



## Inference on AnySat

#### Template of data 

We will use an example from the TreeSatAI-TS dataset.

In [2]:
import rasterio
import torch
from datetime import datetime
import h5py

def day_number_in_year(date_arr, place=4):
    day_number = []
    for date_string in date_arr:
        date_object = datetime.strptime(str(date_string).split('_')[place][:8], '%Y%m%d')
        day_number.append(date_object.timetuple().tm_yday) # Get the day of the year
    return torch.tensor(day_number)

with rasterio.open('.media/Abies_alba_1_1005_WEFL_NLF.tif') as src:
    aerial = torch.FloatTensor(src.read())[:, 2:302, 2:302]
    
with h5py.File('.media/Abies_alba_1_1005_WEFL_NLF.h5', 'r') as file:
    s1_dates = day_number_in_year(file["sen-1-asc-products"][:])
    s2 = torch.tensor(file["sen-2-data"][:])
    s2_dates = day_number_in_year(file["sen-2-products"][:], place=2)
    
s1 = torch.load('.media/Abies_alba_1_1005_WEFL_NLF.pth')

In [3]:
# Normalize data. AnySat requires data to be normalized.

MEAN_AERIAL = torch.tensor([
        150.89349365234375,
        92.7138900756836,
        84.85437774658203,
        80.70423889160156
    ]).float()
STD_AERIAL = torch.tensor([
        36.764923095703125,
        27.62498664855957,
        22.479450225830078,
        26.733688354492188
    ]).float()
MEAN_S2 = torch.tensor([
        4304.32958984375,
        4159.2666015625,
        4057.776611328125,
        4328.951171875,
        4571.22119140625,
        4644.87109375,
        4837.2470703125,
        4700.2578125,
        2823.264404296875,
        2319.97021484375
    ]).float()
STD_S2 = torch.tensor([
        3537.99755859375,
        3324.23486328125,
        3270.070068359375,
        3250.530029296875,
        2897.391357421875,
        2754.4970703125,
        2821.521484375,
        2625.952392578125,
        1731.56298828125,
        1549.3028564453125
    ]).float()
MEAN_S1 = torch.tensor([
        3.2893013954162598,
        -3.682938814163208,
        0.6116273403167725
    ]).float()
STD_S1 = torch.tensor([
        40.11152267456055,
        40.535335540771484,
        1.0343183279037476
    ]).float()

aerial = (aerial - MEAN_AERIAL[:, None, None]) / STD_AERIAL[:, None, None]
s2 = (s2 - MEAN_S2[:, None, None]) / STD_S2[:, None, None]
s1 = (s1 - MEAN_S1[:, None, None]) / STD_S1[:, None, None]

In [4]:
print("aerial shape", aerial.unsqueeze(0).shape)
print("s2 shape", s2.unsqueeze(0).shape)
print("s2_dates shape", s2_dates.unsqueeze(0).shape)
print("s1 shape", s1.unsqueeze(0).shape)
print("s1_dates shape", s1_dates.unsqueeze(0).shape)

aerial shape torch.Size([1, 4, 300, 300])
s2 shape torch.Size([1, 146, 10, 6, 6])
s2_dates shape torch.Size([1, 146])
s1 shape torch.Size([1, 60, 3, 6, 6])
s1_dates shape torch.Size([1, 60])


To get features from a single observation or a batch of observations, provide the model with a dictionary whose keys come from the following list:

| Dataset      | Description        | Tensor Size | Channels                                     | Resolution |
|--------------|--------------------|-------------|----------------------------------------------|------------|
| aerial       | Single date tensor | Bx4xHxW     | RGB, NiR                                     | 0.2m       |
| aerial-flair | Single date tensor | Bx5xHxW     | RGB, NiR, Elevation                          | 0.2m       |
| spot         | Single date tensor | Bx3xHxW     | RGB                                          | 1m         |
| naip         | Single date tensor | Bx4xHxW     | RGB                                          | 1.25m      |
| s2           | Time series tensor | BxTx10xHxW  | B2, B3, B4, B5, B6, B7, B8, B8a, B11, B12    | 10m        |
| s1-asc       | Time series tensor | BxTx2xHxW   | VV, VH                                       | 10m        |
| s1           | Time series tensor | BxTx3xHxW   | VV, VH, Ratio                                | 10m        |
| alos         | Time series tensor | BxTx3xHxW   | HH, HV, Ratio                                | 30m        |
| l7           | Time series tensor | BxTx6xHxW   | B1, B2, B3, B4, B5, B7                       | 30m        |
| l8           | Time series tensor | BxTx11xHxW  | B8, B1, B2, B3, B4, B5, B6, B7, B9, B10, B11 | 10m        |
| modis        | Time series tensor | BxTx7xHxW   | B1, B2, B3, B4, B5, B6, B7                   | 250m       |

In [5]:
data = {
    "aerial": aerial.unsqueeze(0), # batch size 1, 4 channels, 300x300 pixels
    "s2": s2.unsqueeze(0), # batch size 1, 146 dates, 10 channels, 6x6 pixels
    "s2_dates": s2_dates.unsqueeze(0),
    "s1": s1.unsqueeze(0), # batch size 1, 60 dates, 3 channels, 6x6 pixels
    "s1_dates": s1_dates.unsqueeze(0),
}
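
In the dictionary above, each time series (`s2`, `s1`) is paired with a `<key>_dates` entry. A small check along these lines can catch a forgotten companion before calling the model. It is only a sketch: `missing_date_keys` is a hypothetical helper, not part of AnySat, and it assumes the `<key>_dates` naming applies to every time-series key listed in the table.

```python
# Time-series keys, taken from the table above.
TIME_SERIES_KEYS = ("s2", "s1-asc", "s1", "alos", "l7", "l8", "modis")


def missing_date_keys(data: dict) -> list:
    """List the `<key>_dates` entries that are absent for the time series in `data`."""
    missing = []
    for key in TIME_SERIES_KEYS:
        dates_key = f"{key}_dates"
        # Only complain about time series that are actually present.
        if key in data and dates_key not in data:
            missing.append(dates_key)
    return missing


# Placeholder values stand in for tensors; only the keys matter for this check.
example = {"aerial": None, "s2": None, "s2_dates": None, "s1": None}
print(missing_date_keys(example))  # ['s1_dates']
```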

Note that each time series requires a `_dates` companion tensor containing the day of the year: 01/01 = 0, 31/12 = 364.
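
As a sketch of the convention stated here (01/01 = 0), the snippet below converts a `YYYYMMDD` string to a zero-based day of the year. Python's `tm_yday`, which `day_number_in_year` above also relies on, counts from 1, so the example subtracts 1. The function name `day_of_year_zero_based` is illustrative and not an AnySat function.

```python
from datetime import datetime


def day_of_year_zero_based(date_string: str) -> int:
    """Return the day of the year for a YYYYMMDD string, with 1 January mapped to 0."""
    parsed = datetime.strptime(date_string, "%Y%m%d")
    # tm_yday is 1-based (1 January -> 1), so subtract 1 for a 0-based index.
    return parsed.timetuple().tm_yday - 1


print(day_of_year_zero_based("20210101"))  # 0
print(day_of_year_zero_based("20211231"))  # 364 (2021 is not a leap year)
```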

Decide on:
- **Patch size** (in m, must be a multiple of 10): adjust according to the scale of your tiles and GPU memory. In general, avoid having more than 1024 patches per tile.
- **Output type**: Choose between:
  - `'tile'`: Single vector per tile
  - `'patch'`: A vector per patch
  - `'dense'`: A vector per sub-patch. Doubles the size of the vectors
  - `'all'`: A vector per patch with class token at first position
 
The sub-patches are `1x1` pixels for time series and `10x10` pixels for VHR images. If using `output='dense'`, specify the `output_modality`.
The patch size should divide the spatial extent of all modalities and be a multiple of 10.
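
The spatial shape of the output follows from these rules by simple arithmetic. The helpers below are an illustrative sketch: the names `patch_grid` and `dense_grid` are ours and a square tile is assumed. They do not call AnySat; they only compute the grid sizes to expect for the TreeSatAI-TS tile used in this guide, which is 60 m on a side (6 pixels at 10 m, or 300 pixels at 0.2 m).

```python
def patch_grid(tile_extent_m: int, patch_size_m: int) -> tuple:
    """Number of patches along each side of a square tile."""
    if patch_size_m % 10 != 0:
        raise ValueError("patch size must be a multiple of 10 m")
    if tile_extent_m % patch_size_m != 0:
        raise ValueError("patch size must divide the tile extent")
    n = tile_extent_m // patch_size_m
    return (n, n)


def dense_grid(height_px: int, width_px: int, sub_patch_px: int) -> tuple:
    """Number of sub-patches along each axis for output='dense'."""
    return (height_px // sub_patch_px, width_px // sub_patch_px)


TILE_EXTENT_M = 6 * 10  # 6 pixels at 10 m (s2, s1)

for patch_size in (10, 20, 60):
    rows, cols = patch_grid(TILE_EXTENT_M, patch_size)
    print(patch_size, (rows, cols), "patches:", rows * cols)

print(dense_grid(300, 300, 10))  # aerial, 10x10-pixel sub-patches: (30, 30)
print(dense_grid(6, 6, 1))       # s2, 1x1-pixel sub-patches: (6, 6)
```

With a 10 m patch size the tile yields 36 patches, well under the suggested limit of 1024 per tile.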

In [6]:
features = model(data, patch_size=10, output='tile') 
print(features.shape)

torch.Size([1, 768])


In [7]:
features = model(data, patch_size=10, output='patch') 
print(features.shape)

torch.Size([1, 6, 6, 768])


In [8]:
features = model(data, patch_size=20, output='patch') 
print(features.shape)

torch.Size([1, 3, 3, 768])


In [9]:
features = model(data, patch_size=60, output='patch') 
print(features.shape)

torch.Size([1, 1, 1, 768])


In [10]:
features = model(data, patch_size=20, output='dense', output_modality="aerial") 
print(features.shape)

torch.Size([1, 30, 30, 1536])


In [11]:
features = model(data, patch_size=20, output='dense', output_modality="s2") 
print(features.shape)

torch.Size([1, 6, 6, 1536])
