## Install dependencies

In [None]:
!pip install pystac_client==0.6.1 stackstac==0.4.4

In [None]:
!pip install ipywidgets seaborn

---

In [None]:
import os

# Optional AWS credentials for S3 access. Fill these in only for local runs and
# never commit them. Empty values are skipped, so credentials that are already
# exported in the environment are not clobbered with empty strings.
AWS_CREDENTIALS = {
    'AWS_ACCESS_KEY_ID': '',
    'AWS_SECRET_ACCESS_KEY': '',
    'AWS_SESSION_TOKEN': '',
}
for env_var, value in AWS_CREDENTIALS.items():
    if value:
        os.environ[env_var] = value

In [None]:
import gc
import math
from time import perf_counter

import albumentations as A
import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from shapely.geometry import mapping
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize
from tqdm import tqdm

from rastervision.core import RasterStats
from rastervision.core.box import Box
from rastervision.core.data import (
    MinMaxTransformer, RasterioCRSTransformer, 
    StatsTransformer, XarraySource)
from rastervision.core.data import Scene
from rastervision.core.data.raster_source import XarraySource
from rastervision.core.evaluation import ensure_json_serializable
from rastervision.pipeline.file_system.utils import (
    download_if_needed, file_exists, file_to_json, json_to_file)
from rastervision.pytorch_learner import (
    SemanticSegmentationRandomWindowGeoDataset,
    SemanticSegmentationSlidingWindowGeoDataset)
sns.reset_defaults()

# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

---

## Get a time-series of Sentinel-2 images from a STAC API

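Search Element 84's Earth Search STAC API for Sentinel-2 L2A scenes that were acquired during `YEAR`, have less than 5% cloud cover, and intersect a small area of interest in Pakistan (UTM zone 42R). Then stack them into a time series with `stackstac`.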

In [None]:
import pystac_client
import pystac
import stackstac

In [None]:
# Area of interest in lon/lat degrees, sent to the STAC search as a GeoJSON geometry.
bbox = Box(ymin=27.674, xmin=67.882, ymax=27.825, xmax=67.929)
aoi_polygon = bbox.to_shapely()
bbox_geometry = mapping(aoi_polygon.oriented_envelope)
bbox_geometry

In [None]:
# Calendar year of imagery to search for; also part of the output file name on S3.
YEAR = 2018

In [None]:
%%time

URL = "https://earth-search.aws.element84.com/v1"
catalog = pystac_client.Client.open(URL)

items = catalog.search(
    intersects=bbox_geometry,
    collections=["sentinel-2-l2a"],
    datetime=f"{YEAR}-01-01/{YEAR}-12-31",
    query={"eo:cloud_cover": {"lt": 5}},
).item_collection()

len(items)

In [None]:
# Stack all found items into a single xarray DataArray with time, band, y and x
# dimensions (data is read lazily; see the optional `.load()` in the inference loop).
# NOTE(review): no `bounds`/`bounds_latlon` is passed, so the stack presumably spans
# the full extent of the items rather than just the AOI -- confirm this is intended.
stack = stackstac.stack(items)
stack

### Convert to a Raster Vision `RasterSource`

In [None]:
# Converts between pixel coordinates of the stacked grid and map coordinates,
# using the CRS and affine transform attached to the stackstac DataArray.
crs_transformer = RasterioCRSTransformer(
    image_crs=stack.crs,
    transform=stack.transform,
)

In [None]:
# Select 13 bands, in Sentinel-2 band order, to match the 13-channel model input.
# NOTE(review): the L2A collection has no B10 (cirrus) band. 'swir16' (B11) is
# listed twice, apparently as a stand-in at B10's position to keep 13 channels.
# Confirm this matches the band order assumed by SSL4EO_stats.json and the model.
data_array = stack.sel(
    band=[
        'coastal',   # B01
        'blue',      # B02
        'green',     # B03
        'red',       # B04
        'rededge1',  # B05
        'rededge2',  # B06
        'rededge3',  # B07
        'nir',       # B08
        'nir08',     # B8A
        'nir09',     # B09
        'swir16',    # B11, placeholder at B10's position (see note above)
        'swir16',    # B11
        'swir22',    # B12
    ])

In [None]:
# Keep only acquisitions with data at the top-left pixel of the stack, judged on
# the red band (position 3 of the selection above). Missing data is NaN there --
# presumably stackstac's fill value outside an item's footprint; confirm.
# NOTE(review): only this one corner pixel is checked, so an acquisition covering
# the rest of the grid but not this corner is dropped.
corner_red = data_array.isel(x=0, y=0, band=3).to_numpy()
valid_ts = np.flatnonzero(~np.isnan(corner_red))
data_array = data_array.isel(time=valid_ts)
data_array

In [None]:
# Band statistics consumed by StatsTransformer to rescale the raw band values.
# A local copy can be used instead: stats_uri = 'SSL4EO_stats.json'
stats_uri = 's3://raster-vision-ahassan/sentinel-2-embeddings/SSL4EO_stats.json'
stats_tf = StatsTransformer.from_stats_json(stats_uri)

In [None]:
T = len(data_array.time)
# One 'YYYY-MM-DD' string per acquisition; these are the keys of the results dict.
t_strs = np.array([str(ts.date()) for ts in data_array.time.to_series().to_list()])

# Several acquisitions can fall on the same date (e.g. overlapping tiles from
# the same pass). They share a key, and the inference loop skips any key that
# is already present, so at most one of them would be embedded. Say so loudly.
unique_dates, date_counts = np.unique(t_strs, return_counts=True)
duplicate_dates = unique_dates[date_counts > 1]
if len(duplicate_dates) > 0:
    print(
        'WARNING: multiple acquisitions share a date; at most one of each '
        f'will be embedded: {duplicate_dates.tolist()}')

---

## Get model

https://github.com/zhu-xlab/SSL4EO-S12

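Model: a ResNet-18 pretrained with MoCo on 13-band Sentinel-2 L1C imagery (checkpoint `B13_rn18_moco_0099_ckpt.pth`). The imagery used here is L2A, which has no B10 band, so the 13-band input assembled above only approximates what the model was trained on.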

In [None]:
from torch import nn
from torchvision.models import resnet18
from rastervision.pytorch_learner.utils import adjust_conv_channels

In [None]:
# Load the checkpoint onto the CPU. Otherwise tensors are restored to the device
# they were saved from (which may not exist on this machine), and, if saved from
# a GPU, the parts we discard below would stay in GPU memory for as long as `sd`
# is alive. The model is moved to DEVICE after its weights are loaded.
sd = torch.load('./B13_rn18_moco_0099_ckpt.pth', map_location='cpu')
# sd = torch.load(download_if_needed('s3://raster-vision-ahassan/sentinel-2-embeddings/B13_rn18_moco_0099_ckpt.pth', '.'), map_location='cpu')

# Keep only the query-encoder weights, minus its `.fc.` head, and strip the
# 'module.encoder_q.' prefix so the keys match torchvision's resnet18.
sd_encoder_q = {
    k: v for k, v in sd['state_dict'].items()
    if k.startswith('module.encoder_q') and '.fc.' not in k
}
sd_encoder_q_no_prefix = {
    k.replace('module.encoder_q.', ''): v for k, v in sd_encoder_q.items()
}

In [None]:
# torchvision ResNet-18 with a 13-channel first conv and the classifier replaced
# by an identity (so it returns pooled backbone features), loaded with the
# query-encoder weights extracted above and put in inference mode on DEVICE.
model = resnet18(weights=None)
model.conv1 = adjust_conv_channels(model.conv1, 13, pretrained=False)
model.fc = nn.Identity()
model.load_state_dict(sd_encoder_q_no_prefix)
model = model.to(device=DEVICE).eval()

---

## Run inference

In [None]:
dst_file_s3 = f's3://raster-vision-ahassan/sentinel-2-embeddings/42R_{YEAR}.json'
print(dst_file_s3)
# Resume from previously uploaded results, if any. Only a missing file means
# "start from scratch": the loop below re-uploads `embeddings` to this same URI,
# so treating any read error (expired credentials, network failure, ...) as
# "no results yet" would silently overwrite earlier work. Such errors now raise.
if file_exists(dst_file_s3):
    embeddings = file_to_json(dst_file_s3)
else:
    embeddings = {}
print(len(embeddings))

In [None]:
# Sliding-window chipping: non-overlapping chip_sz x chip_sz pixel windows
# (stride == chip_sz), each resized to img_sz x img_sz for the model.
chip_sz = 500
stride = chip_sz
img_sz = 256
resize_tf = A.Resize(img_sz, img_sz)

# DataLoader settings.
batch_size = 8
num_workers = 0

# If True, load each acquisition fully into memory before chipping it.
download_first = False

In [None]:
# Embed every acquisition not yet in `embeddings`, uploading the whole results
# dict after each one so that an interrupted run can be resumed.
for t in range(T):
    t_str = t_strs[t]
    if t_str in embeddings:
        continue
    print(t_str)
    embeddings_t = []
    raster_source_t = XarraySource(
        data_array.isel(time=t),
        crs_transformer=crs_transformer,
        raster_transformers=[stats_tf],
    )
    if download_first:
        print('Downloading data')
        tic = perf_counter()
        raster_source_t.data_array.load()
        toc = perf_counter()
        print(f'Done. ({toc - tic:.1f}s)')
    scene = Scene('', raster_source_t)
    ds = SemanticSegmentationSlidingWindowGeoDataset(
        scene,
        chip_sz,
        stride,
        padding=0,
        transform=resize_tf,
        normalize=False,
    )
    # Not shuffled, so batches arrive in ds.windows order; the mask built below
    # relies on this to line up with ds.windows.
    dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
    # One bool per window. Despite the name, True marks chips that have no
    # NODATA pixels, i.e. the chips that get embedded.
    nodata_mask = []
    with torch.inference_mode(), tqdm(dl, desc=t_str) as bar:
        for x, _ in bar:
            # x is (batch, channels, height, width), the layout the model consumes.
            # Sum over channels (dim=1) -- not over the batch -- so each chip is
            # judged only on its own pixels; keep chips whose every pixel is non-zero.
            mask = (x.sum(dim=1).reshape(len(x), -1) > 0).all(dim=1)
            nodata_mask += mask.tolist()
            if not mask.any():
                continue
            x = x[mask]
            _x = x.to(device=DEVICE)
            out = model(_x)
            embeddings_t.append(out.cpu().numpy())
    # NOTE: `scene` and `ds` presumably still reference this raster source, so its
    # data is only released once they are rebound on the next iteration.
    del raster_source_t
    gc.collect()
    # dtype=bool so that an empty list still yields a valid boolean index.
    nodata_mask = np.array(nodata_mask, dtype=bool)
    n_chips = len(nodata_mask)
    n_nodata = n_chips - int(nodata_mask.sum())
    pct_nodata = 100 * n_nodata / n_chips if n_chips else 0.0
    print(f'% NODATA chips: {pct_nodata:.1f} ({n_nodata}/{n_chips})')
    windows = np.array(ds.windows)[nodata_mask].tolist()
    # If no chip was fully covered there is nothing to concatenate; store an empty
    # array so this date is still recorded (and skipped when resuming).
    if embeddings_t:
        embeddings_t = np.concatenate(embeddings_t)
    else:
        embeddings_t = np.empty((0,), dtype=np.float32)
    embeddings[t_str] = dict(
        embeddings=embeddings_t,
        windows_pixel=windows,
        windows_map=[crs_transformer.pixel_to_map(w) for w in windows]
    )
    print('Uploading to S3')
    tic = perf_counter()
    json_to_file(ensure_json_serializable(embeddings), dst_file_s3)
    toc = perf_counter()
    print(f'Done. ({toc - tic:.1f}s)')
    gc.collect()
    

---