# Check out the Clay v1 model on the dataset

In [3]:
import math

import geopandas as gpd
import numpy as np
import pandas as pd
import pystac_client
import stackstac
import torch
import yaml
from box import Box
from matplotlib import pyplot as plt
from rasterio.enums import Resampling
from shapely import Point
from torchvision.transforms import v2

import shapely
import json

from src.models.clay.model import ClayMAEModule

from src.data.get_satellite_images import ReadSTAC

In [4]:
# Sentinel-2 scenes are looked up through the Planetary Computer STAC API.
api_url = "https://planetarycomputer.microsoft.com/api/stac/v1"
stac_reader = ReadSTAC(api_url=api_url)

# Bands requested when building the image stack
# (presumably red, green, blue for Sentinel-2 — confirm against the band docs).
bands = ["B04", "B03", "B02"]

## Load the dataset

In [5]:
# GeoPackage with the mining tiles. The path is absolute, so adjust it if the
# repository is checked out somewhere else.
tiles_path = "/workspaces/mine-segmentation/data/raw/mining_tiles_with_masks.gpkg"
data = gpd.read_file(tiles_path)

In [6]:
# Use the first tile in the dataset as the working example.
first_idx = 0
row = data.iloc[first_idx]
row

tile_id                                                         621
tile_bbox         {"type": "Polygon", "coordinates": [[[-63.3333...
sentinel_2_id     S2B_MSIL2A_20190517T144739_R139_T20NMP_2020100...
source_dataset                                                 tang
timestamp                                2024-06-19 09:07:50.341000
geometry          MULTIPOLYGON (((-63.417205999535305 7.45592197...
Name: 0, dtype: object

In [7]:
# `tile_bbox` holds a GeoJSON string: parse it, build the shapely geometry,
# and keep its bounding box, which is passed to the STAC lookup below.
tile_geojson = json.loads(row.tile_bbox)
geom = shapely.geometry.shape(tile_geojson)
bounds = geom.bounds

In [8]:
# Fetch the STAC item for the Sentinel-2 scene referenced by this tile,
# passing the tile's bounding box to the lookup.
scene_id = row.sentinel_2_id
item = stac_reader.get_item_by_name(scene_id, bbox=bounds)
item

In [9]:
# Build the image stack for the selected item, restricted to `bands`.
# The recorded output shows a Dask-backed float64 array of shape (3, 1842, 1842),
# so no pixel data has been computed at this point.
stack = stac_reader.get_stack(item, bands)
# Display the array summary as the cell output.
stack

Loading stack...
Returning stack from single S2 image with ID: S2B_MSIL2A_20190517T144739_R139_T20NMP_20201006T152830


Unnamed: 0,Array,Chunk
Bytes,77.66 MiB,8.00 MiB
Shape,"(3, 1842, 1842)","(1, 1024, 1024)"
Dask graph,12 chunks in 5 graph layers,12 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 77.66 MiB 8.00 MiB Shape (3, 1842, 1842) (1, 1024, 1024) Dask graph 12 chunks in 5 graph layers Data type float64 numpy.ndarray",1842  1842  3,

Unnamed: 0,Array,Chunk
Bytes,77.66 MiB,8.00 MiB
Shape,"(3, 1842, 1842)","(1, 1024, 1024)"
Dask graph,12 chunks in 5 graph layers,12 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


### Load the model

Now that we have the data to analyse, let's load the model.

In [11]:
# Run on the GPU when one is available, otherwise fall back to the CPU, and
# make that device the default for newly created tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

# Remote Clay checkpoint. On first use it is downloaded into the torch hub
# cache (about 1.6 GB according to the recorded output), which can take minutes.
ckpt = "https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt"

# `shuffle=False` and `mask_ratio=0` are forwarded to the module; presumably they
# disable patch shuffling/masking for inference — confirm in ClayMAEModule.
model = ClayMAEModule.load_from_checkpoint(
    ckpt,
    metadata_path="configs/metadata.yaml",
    shuffle=False,
    mask_ratio=0,
)

# Switch to inference mode, then move the weights onto the selected device.
model.eval()
model = model.to(device)

Downloading: "https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt" to /root/.cache/torch/hub/checkpoints/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt
 13%|█▎        | 218M/1.61G [00:59<06:29, 3.85MB/s] 


KeyboardInterrupt: 