## Run Prithvi

This exercise walks through a small but complete analysis from scratch. We will:

1. Set a location and date range of interest
2. Download Landsat imagery for this specification
3. Load the model checkpoint
4. Prepare data into a format for the model
5. Run the model on the imagery
6. Analyse the model output (embeddings) using PCA

In [None]:
import sys

sys.path.append("..")

In [None]:
import os
import urllib.parse

import geopandas as gpd
import numpy as np
import pandas as pd
import pystac_client
import stackstac
import torch
import yaml
from einops import rearrange, reduce
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from prithvi import MaskedAutoencoderViT
from rasterio.enums import Resampling
from shapely import Point
from sklearn import decomposition, svm
from sklearn.naive_bayes import MultinomialNB

### Specify location and date of interest
In this example we will use a location in Portugal where a forest fire happened. We will run the model over the time period of the fire and analyse the model embeddings.

In [None]:
# Point over Monchique Portugal
lat, lon = 37.30939, -8.57207

# Dates of a large forest fire
start = "2018-07-01"
end = "2018-09-01"

### Get data from STAC catalog

Based on the location and date we can obtain a stack of imagery using stackstac. Let's start with finding the STAC items we want to analyse.

In [None]:
STAC_API = "https://landsatlook.usgs.gov/stac-server"
COLLECTION = "landsat-c2l2-sr"

# Search the catalogue
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

all_items = search.item_collection()
all_items
item = all_items[0]
item

# Use the S3 links for downloading imagery
for item in all_items:
    for asset in item.assets.values():
        if "alternate" in asset.extra_fields:
            asset.href = asset.extra_fields["alternate"]["s3"]["href"]

# Reduce to LS8 and LS9, keeping one item per acquisition date
items = []
dates = []
for item in all_items:
    if item.datetime.date() not in dates:
        if item.id.startswith("LC08") or item.id.startswith("LC09"):
            items.append(item)
            dates.append(item.datetime.date())


print(f"Found {len(items)} items")

### Create a bounding box around the point of interest

The bounding box has to be defined in the projection of the imagery so that we can generate image chips of the right size.

In [None]:
# Extract coordinate system from first item
epsg = items[0].properties["proj:epsg"]

# Convert point of interest into the image projection
# (assumes all images are in the same projection)
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
).to_crs(epsg)

coords = poidf.iloc[0].geometry.coords[0]

# Create bounds in projection
size = 224
gsd = 30
bounds = (
    coords[0] - (size * gsd) // 2,
    coords[1] - (size * gsd) // 2,
    coords[0] + (size * gsd) // 2,
    coords[1] + (size * gsd) // 2,
)
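
As a quick sanity check on the arithmetic above, the following sketch wraps the same calculation in a small helper and confirms that the box spans `size` pixels at the given ground sampling distance. The function name `chip_bounds` and the example coordinates are hypothetical and only serve as an illustration.

```python
def chip_bounds(x, y, size=224, gsd=30):
    """Return (minx, miny, maxx, maxy) for a square chip centred on (x, y).

    x and y are in projected coordinates (metres), size is the chip
    width in pixels and gsd the pixel size in metres.
    """
    half = (size * gsd) // 2
    return (x - half, y - half, x + half, y + half)


# Hypothetical projected coordinates, not the real point of interest
example_bounds = chip_bounds(500000.0, 4130000.0)
width_m = example_bounds[2] - example_bounds[0]
print(width_m, width_m / 30)  # 6720.0 224.0
```

A 224 pixel chip at 30 m resolution therefore covers 6720 m on each side.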

### Retrieve the imagery data

In [None]:
os.environ["AWS_REQUEST_PAYER"] = "requester"

# Retrieve the pixel values, for the bounding box in
# the target projection. In this example we use the
# blue, green, red, NIR and the two SWIR bands.
stack = stackstac.stack(
    items,
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=gsd,
    dtype="float32",
    rescale=False,
    fill_value=0,
    assets=["blue", "green", "red", "nir08", "swir16", "swir22"],
    resampling=Resampling.nearest,
)

print(f"Working with stack of size {stack.shape}")

stack = stack.compute()

stack

### Let's have a look at the imagery we just downloaded

The imagery will contain 7 dates before the fire, of which two are pretty cloudy images. There are also 5 images after the forest fire.

In [None]:
stack.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=20000, col_wrap=6
)

### Load the model

Now that we have the data to analyse, let's load the model.

In [None]:
# Set up config
REPO_ID = "ibm-nasa-geospatial/Prithvi-100M"
CONFIG = "Prithvi_100M_config.yaml"
CHECKPOINT = "Prithvi_100M.pt"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001

# Download and load configuration and checkpoint
config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG)
with open(config_path) as f:
    config = yaml.safe_load(f)
checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=CHECKPOINT)

# Initialize the model
model_args = config["model_args"]
model = MaskedAutoencoderViT(**model_args)
model = model.to(DEVICE)
state_dict = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

### Convert the band pixel data into the format for the model

We will take the information in the stack of imagery and convert it into the format that the model requires.

For Prithvi, this means creating input triples. The model takes a short time series of 3 images as input.
Since we do not have a lot of data in this example, we create the triples with overlapping dates:
the first triple covers dates 1 to 3, the second one dates 2 to 4, and so on. This leads
to six triples with the data we have here.
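
The overlapping windows described above can be sketched with plain Python. The helper name `sliding_triples` and the count of eight dates are assumptions made for illustration; the indices are zero-based, whereas the text counts dates from 1.

```python
def sliding_triples(n_images, window=3):
    """Return index tuples for overlapping windows over n_images time steps."""
    return [tuple(range(i, i + window)) for i in range(n_images - window + 1)]


# Hypothetical example: 8 acquisition dates yield 6 overlapping triples
triples = sliding_triples(8)
print(triples[0], triples[1], len(triples))  # (0, 1, 2) (1, 2, 3) 6
```

In general, `n` dates give `n - 2` overlapping triples, and fewer than three dates give none.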

In [None]:
mean = config["train_params"]["data_mean"]
std = config["train_params"]["data_std"]

chips = []
for i in range(6):
    chips.append(stack.isel(time=slice(i, i + 3)).values)
chips[0].shape

### Run the model

Pass the chips we prepared to the model to create embeddings. This will create one embedding vector for each input triple.

In [None]:
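# Reshape the band statistics to (bands, 1, 1) so that they can be
# broadcast against a chip of shape (time, bands, height, width).
# Note that this loop passes the raw pixel values to the model and
# does not apply the statistics.
mean = np.array(mean).reshape(-1, 1, 1)
std = np.array(std).reshape(-1, 1, 1)

embeddings = []

for ts in chips:
    ts = rearrange(ts, "t c h w -> 1 c t h w")
    ts = ts.astype(np.float32)
    # Skip incomplete triples at the end of the time series
    if ts.shape[2] == 3:
        # Move the input to the same device as the model
        batch = torch.from_numpy(ts).to(DEVICE)
        with torch.no_grad():
            embedding = model.forward_encoder(batch, mask_ratio=0.0)
        # The first token is the class token
        cls_embedding = embedding[:, 0, :].detach().cpu().numpy().ravel()
        embeddings.append(cls_embedding)

embeddings = np.array(embeddings)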
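
The `mean` and `std` arrays have the shape `(bands, 1, 1)`, which is the shape needed to broadcast them against a chip of shape `(time, bands, height, width)`. The sketch below shows what such a band-wise normalization could look like. It uses synthetic pixel values and made-up statistics (the names `band_mean` and `band_std` are hypothetical), so it illustrates the broadcasting only and is not a statement about how the model was trained.

```python
import numpy as np

# Synthetic chip: 3 time steps, 6 bands, 4x4 pixels (made-up values)
rng = np.random.default_rng(0)
chip = rng.uniform(0, 10000, size=(3, 6, 4, 4)).astype(np.float32)

# Hypothetical per-band statistics, one value per band
band_mean = np.array([500.0, 800.0, 900.0, 2500.0, 2000.0, 1500.0])
band_std = np.array([300.0, 350.0, 500.0, 900.0, 1000.0, 900.0])

# Reshape to (bands, 1, 1) so the statistics broadcast over
# the time, height and width dimensions
band_mean = band_mean[:, None, None]
band_std = band_std[:, None, None]

normalized = ((chip - band_mean) / band_std).astype(np.float32)

print(normalized.shape)  # (3, 6, 4, 4)
```

The normalization has to happen while the chip is still in `t c h w` order, because the `(bands, 1, 1)` shape no longer lines up with the band axis once the chip is rearranged to `1 c t h w`.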

### Analyse the embeddings

A simple analysis of the embeddings is to reduce each of them to a single number using Principal Component Analysis (PCA). For this we fit a PCA on the six embeddings computed above and apply the dimensionality reduction to them. We will see a separation into three groups: the images before the fire, the cloudy images, and the images after the fire each fall into a different range of the PCA space.

In [None]:
# Run PCA
pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings)

plt.xticks(rotation=-45)

# Plot all points in blue first; each triple is placed at the date of its first image
plt.scatter(stack.time[:6], pca_result, color="blue")

# Re-plot the cloudy image in green
plt.scatter(stack.time[0], pca_result[0], color="green")

# Color all images after the fire in red
plt.scatter(stack.time[3:6], pca_result[3:6], color="red")
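
To make the reduction to a single number more concrete, here is a minimal sketch of a one-component PCA written with NumPy's singular value decomposition. It runs on synthetic data (`fake_embeddings` is a made-up stand-in for the real embeddings) and also reports how much of the total variance the single component captures. The projection should agree with a one-component scikit-learn PCA up to sign.

```python
import numpy as np

rng = np.random.default_rng(42)
# Synthetic stand-in for the embeddings: 6 samples with 16 features each
fake_embeddings = rng.normal(size=(6, 16))

# PCA works on mean-centred data
centred = fake_embeddings - fake_embeddings.mean(axis=0)

# Rows of vt are the principal axes, s holds the singular values
u, s, vt = np.linalg.svd(centred, full_matrices=False)

# Projection onto the first principal component, shape (6,)
first_component = centred @ vt[0]

# Fraction of the total variance captured by that component
explained = s[0] ** 2 / np.sum(s**2)
print(first_component.shape, round(float(explained), 3))
```

With only a handful of samples, the explained fraction is a useful check on how much information is lost by looking at a single component.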

### One embedding per time step

Prithvi will also output embeddings for each time step that was passed to it.

In this example we extract the three individual time step embeddings for each
input triple and visualize those using PCA as in the previous example.
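
The extraction of per-time-step embeddings can be illustrated on a toy array. This sketch assumes that the encoder output consists of one class token followed by the patch tokens ordered time step by time step, which is what the `(t n)` pattern in the notebook's `rearrange` call implies. The sizes and the name `tokens` are made up for the example.

```python
import numpy as np

t, n, d = 3, 4, 8  # time steps, patches per time step, embedding size (toy values)

# Synthetic encoder output: 1 class token followed by t * n patch tokens
tokens = np.arange((1 + t * n) * d, dtype=np.float32).reshape(1, 1 + t * n, d)

# Drop the class token, then split the patch tokens by time step
patches = tokens[0, 1:, :].reshape(t, n, d)

# Average over the patches of each time step: one vector per time step
per_step = patches.mean(axis=1)
print(per_step.shape)  # (3, 8)
```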

In [None]:
long_embeddings = []

for ts in chips:
    ts = rearrange(ts, "t c h w -> 1 c t h w")
    ts = ts.astype(np.float32)
    if ts.shape[2] == 3:
        # Move the input to the same device as the model
        batch = torch.from_numpy(ts).to(DEVICE)
        with torch.no_grad():
            embedding = model.forward_encoder(batch, mask_ratio=0.0)
        # Drop the class token and group the patch tokens by time step
        embedding = rearrange(embedding[:, 1:, :], "1 (t n) d -> 1 t n d", t=3)[0]
        # Average the patch tokens of each time step
        embedding = reduce(embedding, "t n d -> t d", "mean").detach().cpu().numpy()
        # Keep one embedding per time step
        long_embeddings.extend(embedding)

In [None]:
# Run PCA
pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(long_embeddings)

plt.xticks(rotation=-45)

# Plot one point per time-step embedding
plt.scatter(np.arange(len(pca_result)), pca_result, color="blue")

### And finally, some finetuning

We are going to train a classifier head on the embeddings and use it to detect fires. The Prithvi encoder itself is not updated.

In [None]:
# Label the images we downloaded
# 0 = Cloud
# 1 = Forest
# 2 = Fire
labels = np.array([0, 1, 1, 2, 2, 2])

# Split into fit and test manually. All 3 classes are in the fit set; the test
# set has no cloud sample because only one image is labelled as cloud
fit = [0, 1, 3]
test = [2, 4, 5]

# Train a support vector machine model
clf = svm.SVC()
clf.fit(embeddings[fit], labels[fit])

# Predict classes on test set
prediction = clf.predict(embeddings[test])

# Perfect match for SVM
match = np.sum(labels[test] == prediction)
print(f"Matched {match} out of {len(test)} correctly")

In [None]:
# Naive bayes does not learn about the clouds.
# MultinomialNB expects non-negative features, hence the constant offset of 100
clf = MultinomialNB()
clf.fit(embeddings[fit] + 100, labels[fit])

# Predict classes on test set
prediction = clf.predict(embeddings[test] + 100)

match = np.sum(labels[test] == prediction)
print(f"Matched {match} out of {len(test)} correctly")