# Part 3: SatCLIP

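SatCLIP is a location encoder pretrained with contrastive learning: it learns to match globally sampled Sentinel-2 satellite images (the S2-100K dataset, about 100K images) with the coordinates where they were captured. The resulting model maps any longitude/latitude pair to a general-purpose embedding that can be used as input features for downstream tasks. You can read more about it in the [paper](https://arxiv.org/abs/2311.17179) or check out the [GitHub repo](https://github.com/microsoft/satclip/).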

## Environment Setup

We start by cloning the **SatCLIP** code and installing its dependencies.

In [None]:
!git clone https://github.com/microsoft/satclip.git # Clone SatCLIP repository

In [4]:
!pip install lightning --quiet
!pip install rasterio --quiet
!pip install torchgeo --quiet
!pip install basemap --quiet

Choose a SatCLIP model from the list of available pretrained models [here](https://github.com/microsoft/satclip#pretrained-models). They all perform somewhat similarly. Let's download a SatCLIP using a ResNet18 vision encoder and $L=40$ Legendre polynomials in the location encoder (i.e., a high-resolution SatCLIP).

In [2]:
!wget 'https://satclip.z13.web.core.windows.net/satclip/satclip-resnet18-l40.ckpt'

--2025-02-07 14:16:13--  https://satclip.z13.web.core.windows.net/satclip/satclip-resnet18-l40.ckpt
Resolving satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)... 52.239.221.231
Connecting to satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)|52.239.221.231|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 75633927 (72M) [application/zip]
Saving to: ‘satclip-resnet18-l40.ckpt’


2025-02-07 14:16:27 (5.23 MB/s) - ‘satclip-resnet18-l40.ckpt’ saved [75633927/75633927]
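
To build some intuition for the $L$ in the checkpoint name, the sketch below evaluates Legendre polynomials of increasing degree and counts how often each one changes sign on $[-1, 1]$. Higher degrees oscillate faster, which is the rough reason a larger $L$ can resolve finer spatial patterns. This is an illustration only and not SatCLIP's implementation; the helper names `legendre` and `count_sign_changes` are made up for this example.

```python
def legendre(n, x):
    """Evaluate the degree-n Legendre polynomial P_n(x) via Bonnet's recursion."""
    p_prev, p_curr = 1.0, x  # P_0(x) = 1, P_1(x) = x
    if n == 0:
        return p_prev
    for k in range(1, n):
        # (k + 1) P_{k+1}(x) = (2k + 1) x P_k(x) - k P_{k-1}(x)
        p_prev, p_curr = p_curr, ((2 * k + 1) * x * p_curr - k * p_prev) / (k + 1)
    return p_curr


def count_sign_changes(n, num_points=2001):
    """Count how often P_n changes sign on an evenly spaced grid over [-1, 1]."""
    xs = [-1.0 + 2.0 * i / (num_points - 1) for i in range(num_points)]
    # Drop exact zeros so that a root sitting on a grid point still counts once
    values = [v for v in (legendre(n, x) for x in xs) if v != 0.0]
    return sum(1 for a, b in zip(values, values[1:]) if a * b < 0)


for degree in (2, 10, 40):
    print(degree, count_sign_changes(degree))
```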



Load required packages.

In [None]:
import sys
sys.path.append('./satclip/satclip')

import torch
from load import get_satclip

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Automatically select device
print(device)

## Get air temperature dataset

In this example, we work with a dataset of annual mean temperature values recorded at ~3000 locations around the planet, introduced [here](https://www.nature.com/articles/sdata2018246). To download the dataset, we can use code from the [PE-GNN repository](https://github.com/konstantinklemmer/pe-gnn/).

In [9]:
from urllib import request
import numpy as np
import pandas as pd
import io
import torch


def get_air_temp_data(pred="temp", norm_y=True, norm_x=True):
  '''
  Download and process the Global Air Temperature dataset (more info: https://www.nature.com/articles/sdata2018246)

  Parameters:
  pred = string; outcome variable to be returned; choose from ["temp", "prec"]
  norm_y = logical; should the outcome be scaled by its maximum value
  norm_x = logical; should the feature be scaled by its maximum value

  Return:
  coords = spatial coordinates (lon/lat)
  x = feature at location (the variable not chosen as outcome)
  y = outcome variable
  '''
  url = 'https://springernature.figshare.com/ndownloader/files/12609182'
  url_open = request.urlopen(url)
  inc = np.array(pd.read_csv(io.StringIO(url_open.read().decode('utf-8'))))
  coords = inc[:,:2] # First two columns hold the coordinates
  if pred=="temp":
    y = inc[:,4].reshape(-1)
    x = inc[:,5]
  else:
    y = inc[:,5].reshape(-1)
    x = inc[:,4]
  if norm_y:
    y = y / y.max()
  if norm_x:
    x = x / x.max()

  return torch.tensor(coords), torch.tensor(x), torch.tensor(y)

In [10]:
coords, _, y = get_air_temp_data()
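
Since the coordinates are passed to the location encoder later on, it can be worth checking that they really are in (lon, lat) order and in degrees. Below is a minimal sketch of such a check. `check_lonlat` is a hypothetical helper written for this example, and the demo tensor is a hand-made stand-in for the real `coords`.

```python
import torch


def check_lonlat(coords_to_check):
    """Raise if an (N, 2) tensor does not look like (longitude, latitude) in degrees."""
    if coords_to_check.ndim != 2 or coords_to_check.shape[1] != 2:
        raise ValueError(f"expected shape (N, 2), got {tuple(coords_to_check.shape)}")
    lon, lat = coords_to_check[:, 0], coords_to_check[:, 1]
    if lon.abs().max() > 180:
        raise ValueError("first column exceeds [-180, 180]; is it really longitude?")
    if lat.abs().max() > 90:
        raise ValueError("second column exceeds [-90, 90]; lon/lat may be swapped")
    return True


# Hand-made stand-in for the real `coords` tensor: three (lon, lat) pairs
demo_coords = torch.tensor(
    [[-122.3, 47.6], [151.2, -33.9], [13.4, 52.5]], dtype=torch.float64
)
print(check_lonlat(demo_coords))
```

In the notebook, the same check could be run as `check_lonlat(coords)`. Note that it cannot catch every swap: a dataset whose longitudes all fall within [-90, 90] would pass either way.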

Let's plot our data. Here we show a map of the world with our locations colored by mean temperatures.

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap

fig, ax = plt.subplots(1, figsize=(5, 3))

m = Basemap(projection='cyl', resolution='c', ax=ax)
m.drawcoastlines()
ax.scatter(coords[:,0], coords[:,1], c=y, s=5)
ax.set_title('Annual Mean Temperatures')

## Predictive modeling

We now want to predict annual mean temperature values using SatCLIP embeddings. First, we need to obtain the location embeddings for our dataset from the pretrained model.

In [None]:
satclip_path = 'satclip-resnet18-l40.ckpt'

model = get_satclip(satclip_path, device=device) # Only loads location encoder by default
model.eval()
with torch.no_grad():
  x = model(coords.double().to(device)).cpu() # No gradients are tracked here, so detach() is not needed

We have now collected a 256-dimensional location embedding for each latitude/longitude coordinate in our dataset.

In [None]:
print(coords.shape)
print(x.shape)
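
Embedding all coordinates in a single forward pass is fine for a few thousand points. For much larger coordinate sets, memory can become a constraint, and one option is to embed in chunks. The sketch below shows the idea with a stand-in encoder (a plain `nn.Linear`, not SatCLIP) so that it runs without the checkpoint; `embed_in_batches` is a hypothetical helper, not part of the SatCLIP repository.

```python
import torch
import torch.nn as nn


def embed_in_batches(encoder, all_coords, batch_size=1024, device="cpu"):
    """Run `encoder` over `all_coords` in chunks and return all embeddings on the CPU."""
    encoder.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(all_coords), batch_size):
            batch = all_coords[start:start + batch_size].to(device)
            chunks.append(encoder(batch).cpu())
    return torch.cat(chunks, dim=0)


# Stand-in encoder (NOT SatCLIP): maps 2-D coordinates to 256-D vectors
toy_encoder = nn.Linear(2, 256).double()
toy_coords = torch.rand(2500, 2, dtype=torch.float64)

toy_embeddings = embed_in_batches(toy_encoder, toy_coords, batch_size=1000)
print(toy_embeddings.shape)
```

With the real model, the call would presumably look like `embed_in_batches(model, coords.double(), device=device)`.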

Next, let's split our dataset into 50% training and 50% testing data.

In [14]:
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(coords, x, y)

train_size = int(0.5 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])

# Materialize each split as plain tensors by indexing with the sampled indices
coords_train, x_train, y_train = (t[train_set.indices] for t in dataset.tensors)
coords_test, x_test, y_test = (t[test_set.indices] for t in dataset.tensors)
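
The split above is drawn fresh on every run, so the reported test loss will vary slightly between runs. If a repeatable split is wanted, `random_split` accepts a seeded `torch.Generator`. The sketch below demonstrates this on a small toy dataset; `seeded_split` is a hypothetical wrapper and the seed value is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def seeded_split(data, train_fraction=0.5, seed=0):
    """Split `data` at random, using a fixed seed so the split is repeatable."""
    train_size = int(train_fraction * len(data))
    test_size = len(data) - train_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(data, [train_size, test_size], generator=generator)


# Toy dataset standing in for TensorDataset(coords, x, y)
toy_dataset = TensorDataset(torch.arange(10).float())

first_train, first_test = seeded_split(toy_dataset, seed=42)
second_train, second_test = seeded_split(toy_dataset, seed=42)

# The same seed yields the same indices on both calls
print(first_train.indices == second_train.indices)
```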

In [None]:
fig, ax = plt.subplots(1, figsize=(5, 3))

m = Basemap(projection='cyl', resolution='c', ax=ax)
m.drawcoastlines()
ax.scatter(coords_train[:,0], coords_train[:,1], c='blue', s=2, label='Training',alpha=0.5)
ax.scatter(coords_test[:,0], coords_test[:,1], c='green', s=2, label='Testing',alpha=0.5)
ax.legend()
ax.set_title('Train-Test Split')

Next we define our prediction model, a simple MLP. (NOTE: SatCLIP embeddings can of course be used with any predictive model; we opt for an MLP here for simplicity. You can also directly unfreeze the location encoder `model` above and fine-tune it.)

In [16]:
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim, dim_hidden, num_layers, out_dims):
        super(MLP, self).__init__()

        layers = []
        layers += [nn.Linear(input_dim, dim_hidden, bias=True), nn.ReLU()] # Input layer
        for _ in range(num_layers): # Hidden layers; a fresh nn.Linear each time so weights are not shared
            layers += [nn.Linear(dim_hidden, dim_hidden, bias=True), nn.ReLU()]
        layers += [nn.Linear(dim_hidden, out_dims, bias=True)] # Output layer

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)
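
Because the embeddings work with any predictive model, a linear baseline is a cheap point of comparison for the MLP. Below is a minimal sketch of closed-form ridge regression in PyTorch, run on synthetic stand-in data rather than the real embeddings. The helpers `fit_ridge` and `predict_ridge` and the regularization strength are assumptions made for this example.

```python
import torch


def add_bias_column(features):
    """Append a column of ones so the last weight acts as an intercept."""
    ones = torch.ones(features.shape[0], 1, dtype=features.dtype)
    return torch.cat([features, ones], dim=1)


def fit_ridge(features, targets, lam=1e-3):
    """Solve (X^T X + lam * I) w = X^T y; for simplicity the intercept is regularized too."""
    xb = add_bias_column(features)
    eye = torch.eye(xb.shape[1], dtype=features.dtype)
    return torch.linalg.solve(xb.T @ xb + lam * eye, xb.T @ targets)


def predict_ridge(weights, features):
    """Apply fitted ridge weights (including the intercept) to new features."""
    return add_bias_column(features) @ weights


# Synthetic stand-in for (x_train, y_train): a noise-free linear relationship
torch.manual_seed(0)
toy_features = torch.randn(200, 16, dtype=torch.float64)
toy_targets = toy_features @ torch.randn(16, dtype=torch.float64) + 0.5

ridge_weights = fit_ridge(toy_features, toy_targets)
toy_mse = ((predict_ridge(ridge_weights, toy_features) - toy_targets) ** 2).mean()
print(f"Ridge training MSE on toy data: {toy_mse.item():.2e}")
```

On the real data, the equivalent calls would presumably be `fit_ridge(x_train.double(), y_train.double())` followed by `predict_ridge(..., x_test.double())`.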

In [None]:
model

Let's set up and run the training loop.

In [None]:
pred_model = MLP(input_dim=256, dim_hidden=64, num_layers=2, out_dims=1).float().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(pred_model.parameters(), lr=0.001)

losses = []
epochs = 3000

for epoch in range(epochs):
  optimizer.zero_grad()
  # Forward pass
  y_pred = pred_model(x_train.float().to(device))
  # Compute the loss
  loss = criterion(y_pred.reshape(-1), y_train.float().to(device))
  # Backward pass
  loss.backward()
  # Update the parameters
  optimizer.step()
  # Append the loss to the list
  losses.append(loss.item())
  if (epoch + 1) % 250 == 0:
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
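
The loop above performs full-batch gradient descent, which works well for a dataset of this size. For larger datasets, a mini-batch variant using `DataLoader` is a common alternative. The sketch below runs on synthetic data with a linear model so that it is self-contained; the `toy_` names, batch size, and learning rate are illustrative assumptions, not tuned values.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic regression problem standing in for (x_train, y_train)
torch.manual_seed(0)
toy_inputs = torch.randn(256, 8)
toy_outputs = toy_inputs @ torch.randn(8)

toy_loader = DataLoader(TensorDataset(toy_inputs, toy_outputs), batch_size=32, shuffle=True)
toy_model = nn.Linear(8, 1)
toy_criterion = nn.MSELoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.05)

epoch_losses = []
for epoch in range(20):
    running = 0.0
    for batch_x, batch_y in toy_loader:
        toy_optimizer.zero_grad()
        batch_loss = toy_criterion(toy_model(batch_x).reshape(-1), batch_y)
        batch_loss.backward()
        toy_optimizer.step()
        # Weight each batch loss by its size to get a per-sample epoch average
        running += batch_loss.item() * len(batch_x)
    epoch_losses.append(running / len(toy_inputs))

print(f"First epoch loss: {epoch_losses[0]:.4f}, last epoch loss: {epoch_losses[-1]:.4f}")
```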

Let's make predictions on the test set!

In [None]:
pred_model.eval() # Switch the MLP (not the SatCLIP encoder) to evaluation mode
with torch.no_grad():
  y_pred_test = pred_model(x_test.float().to(device))

# Print test loss
print(f'Test loss: {criterion(y_pred_test.reshape(-1), y_test.float().to(device)).item()}')
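
The test loss is an MSE on targets that were divided by their maximum, so it is in normalized units and can be hard to interpret on its own. Reporting the mean absolute error and $R^2$ alongside it may help. The sketch below computes all three on a tiny hand-made example; `regression_report` is a hypothetical helper.

```python
import torch


def regression_report(y_true, y_hat):
    """Return MSE, MAE, and R^2 for two tensors holding the same number of values."""
    y_true = y_true.reshape(-1).float()
    y_hat = y_hat.reshape(-1).float()
    residual = y_true - y_hat
    ss_res = (residual ** 2).sum()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return {
        "mse": (residual ** 2).mean().item(),
        "mae": residual.abs().mean().item(),
        "r2": (1.0 - ss_res / ss_tot).item(),
    }


# Hand-made example: three exact predictions and one that is off by 1
toy_true = torch.tensor([0.0, 1.0, 2.0, 3.0])
toy_hat = torch.tensor([0.0, 1.0, 2.0, 2.0])
print(regression_report(toy_true, toy_hat))
```

In the notebook, this could be called as `regression_report(y_test, y_pred_test.cpu())`.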

Let's show the results on a map!

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 3))

m = Basemap(projection='cyl', resolution='c', ax=ax[0])
m.drawcoastlines()
ax[0].scatter(coords_test[:,0], coords_test[:,1], c=y_test, s=5)
ax[0].set_title('True')

m = Basemap(projection='cyl', resolution='c', ax=ax[1])
m.drawcoastlines()
ax[1].scatter(coords_test[:,0], coords_test[:,1], c=y_pred_test.cpu().reshape(-1), s=5)
ax[1].set_title('Predicted')