Vision-language models
===

In this notebook, we use SkyCLIP, a vision-language model trained on Earth imagery, for text-to-image retrieval and zero-shot classification over NAIP aerial imagery.

See [this blog post](https://element84.com/machine-learning/towards-a-queryable-earth-with-vision-language-foundation-models/) for how text-to-image retrieval can be scaled to large regions, making them queryable with natural language.

In [None]:
# Set AWS_REQUEST_PAYER=requester in this kernel's environment so that S3 reads of
# requester-pays buckets are allowed (charges go to your AWS account).
# Presumably needed for the NAIP image downloads further down — TODO confirm.
%env AWS_REQUEST_PAYER=requester

In [None]:
from rastervision.pipeline.file_system.utils import (download_or_copy,
                                                     list_paths)
from rastervision.core.box import Box
from rastervision.pytorch_learner.dataset import (
    SemanticSegmentationSlidingWindowGeoDataset)

from tqdm.auto import tqdm
import torch
from torch.utils.data import ConcatDataset, DataLoader
from shapely.geometry import mapping
import pystac_client
import geopandas as gpd
import contextily as cx
from matplotlib import pyplot as plt
import seaborn as sns

sns.reset_defaults()

# Use the GPU when one is visible, otherwise fall back to CPU so DEVICE-aware
# code still runs (slowly) on CPU-only machines instead of failing outright.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

---

# Load SkyCLIP model

In [None]:
# Fetch the SkyCLIP checkpoint into data/, where the loading cell below reads it.
# The source is an internal bucket; presumably requires AWS credentials with read access.
!aws s3 cp s3://ml-workshop-internal/vlm/skyclip.pt data/skyclip.pt

In [None]:
import open_clip

# Build the ViT-L-14 CLIP architecture. No `pretrained` tag is passed, so the
# actual weights come from the SkyCLIP checkpoint loaded in the next cell.
model_name = 'ViT-L-14'
# create_model_and_transforms returns (model, train_transform, eval_transform);
# the train transform is discarded.
# NOTE(review): `preprocess` is never applied to the raster-vision chips used below —
# confirm whether SkyCLIP expects CLIP-style resize/normalize on its inputs.
model, _, preprocess = open_clip.create_model_and_transforms(model_name)
# Text tokenizer matching the architecture; used to encode text queries / class names.
tokenizer = open_clip.get_tokenizer(model_name)

In [None]:
ckpt_path = 'data/skyclip.pt'
# Stage the tensors on the CPU: load_state_dict copies them into the (CPU) model and
# the whole model is moved to DEVICE afterwards, so loading onto the GPU first only
# doubles the GPU memory footprint of the weights.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']
# Keys are expected to carry a 'module.' prefix (as added by DataParallel/DDP
# wrappers). Strip it only where it is actually present instead of blindly chopping
# the first 7 characters off every key.
ckpt = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in ckpt.items()
}
# load_state_dict is strict by default, so any missing/unexpected key raises here.
message = model.load_state_dict(ckpt)
# Respect DEVICE rather than hard-coding .cuda().
model = model.to(DEVICE).eval()

In [None]:
# Total number of parameters in the loaded model, printed with thousands separators.
n_params = sum(param.numel() for param in model.parameters())
print('#params: ', f'{n_params:,}')

---

# Get imagery

In [None]:
# Area of interest in lon/lat degrees (x = longitude, y = latitude).
# xmin must be the smaller (western) longitude and xmax the larger one; with the two
# swapped the box is inverted and needs shapely post-processing to become a valid
# rectangle. With correct ordering the plain polygon can be used directly.
bbox = Box(ymin=39.889060, xmin=-75.246207, ymax=39.989442, xmax=-75.104968)
bbox_polygon = bbox.to_shapely()
# GeoJSON-like dict, as accepted by the STAC search's `intersects` argument.
search_geometry = mapping(bbox_polygon)

In [None]:
# Element 84's Earth Search STAC API.
STAC_API_URL = 'https://earth-search.aws.element84.com/v1'
catalog = pystac_client.Client.open(STAC_API_URL)

# NAIP scenes that intersect the AOI within the date window.
naip_search = catalog.search(
    collections=['naip'],
    intersects=search_geometry,
    datetime='2019-10-01/2020-01-01',
)
items = naip_search.item_collection()
len(items)

In [None]:
# STAC item geometries are GeoJSON, which is always WGS84 lon/lat; record that CRS
# on the frame so it can be plotted/reprojected without re-specifying it each time.
gdf = gpd.GeoDataFrame.from_features(items, crs='EPSG:4326')

In [None]:
# Outline the footprints of the NAIP scenes found above over a basemap for context.
fig, ax = plt.subplots()
gdf.plot(ax=ax, ec='k', fc='none')
cx.add_basemap(ax, crs='epsg:4326', source=cx.providers.CartoDB.Voyager)
ax.set(xlabel='longitude', ylabel='latitude')
plt.show()

# Generate vector embeddings

In [None]:
# Copy every scene's 'image' asset into the local data/naip directory.
for naip_item in items:
    image_href = naip_item.assets['image'].href
    download_or_copy(image_href, 'data/naip', delete_tmp=True)

In [None]:
# Local GeoTIFFs downloaded above. Their order here fixes the chip order of the
# concatenated dataset `ds`, and therefore the row order of `embs`.
img_uris = list_paths('data/naip', ext='tif')
img_uris

In [None]:
# One sliding-window dataset per scene: non-overlapping 400x400 px windows
# (stride == size), each resized to 224x224 for the image encoder.
# channel_order=[0, 1, 2] keeps the first three bands (presumably RGB of the NAIP scenes).
dses = [
    SemanticSegmentationSlidingWindowGeoDataset.from_uris(
        image_uri=uri,
        image_raster_source_kw=dict(channel_order=[0, 1, 2]),
        size=400,
        stride=400,
        out_size=224,
    )
    for uri in img_uris
]
ds = ConcatDataset(dses)
len(ds)

In [None]:
# Iterate in dataset order (no shuffling) so row i of `embs` corresponds to chip ds[i].
dl = DataLoader(ds, batch_size=16, shuffle=False, num_workers=4)

In [None]:
# Width of the ViT-L-14 image embedding; row k of `embs` holds the embedding of ds[k].
EMBEDDING_DIM_SIZE = 768
embs = torch.zeros(len(ds), EMBEDDING_DIM_SIZE)

# NOTE(review): chips are fed to the encoder straight from raster-vision; the CLIP
# `preprocess` transform is not applied — confirm SkyCLIP needs no mean/std normalization.
with torch.inference_mode():
    start = 0
    for batch, _ in tqdm(dl, desc='Creating chip embeddings'):
        stop = start + len(batch)
        embs[start:stop] = model.encode_image(batch.to(DEVICE)).cpu()
        start = stop

# L2-normalize so that dot products between embeddings are cosine similarities.
embs /= embs.norm(dim=-1, keepdim=True)

embs.shape

In [None]:
# Save the L2-normalized chip embeddings to disk; row order matches the chips of `ds`.
embs_path = 'data/skyclip_naip_embeddings.pt'
torch.save(embs, embs_path)

In [None]:
# Optional: upload the saved embeddings to the workshop bucket.
# Replace <YOUR NAME> and uncomment the line below to run it.
# !aws s3 cp {embs_path} s3://ml-workshop-internal/2024_05_02/<YOUR NAME>/skyclip_naip_embeddings.pt

# Text-to-image retrieval

In [None]:
# Flatten the per-scene window lists in ConcatDataset order, so windows[k] is the
# window of chip ds[k] (and of embedding embs[k]). A nested comprehension is linear;
# `sum(lists, [])` re-copies the growing list at every step, which is quadratic.
windows = [window for scene_ds in ds.datasets for window in scene_ds.windows]
len(windows)

In [None]:
def get_chip_scores(text_queries, embs):
    """Score every chip embedding against a single text query.

    Args:
        text_queries: list containing exactly one query string.
        embs: (N, D) tensor of chip embeddings (row order follows ``ds``).

    Returns:
        (N,) tensor of cosine similarities between the query and each chip.
    """
    assert len(text_queries) == 1
    tokens = tokenizer(text_queries)
    with torch.inference_mode():
        query_emb = model.encode_text(tokens.to(DEVICE))
        query_emb = (query_emb / query_emb.norm(dim=-1, keepdim=True)).cpu()
        # (1, D) query broadcasts against (N, D) chips -> (N,) scores.
        return torch.cosine_similarity(query_emb, embs)


def emb_idx_to_chip(i, windows, out_shape=(400, 400), dataset=None):
    """Return the chip for embedding index ``i`` in (H, W, C) layout for plotting.

    Args:
        i: chip/embedding index (an int or a 0-d tensor such as a ``topk`` index).
        windows: unused; kept so existing callers keep working.
        out_shape: unused; kept so existing callers keep working.
        dataset: indexable returning ``(chip, label)`` with a channels-first chip.
            Defaults to the notebook-global ``ds`` so existing calls are unchanged;
            pass it explicitly to avoid depending on global state.

    Returns:
        The chip tensor permuted from (C, H, W) to (H, W, C), as ``imshow`` expects.
    """
    if dataset is None:
        dataset = ds
    chip, _ = dataset[int(i)]
    return chip.permute(1, 2, 0)


def show_top_chips(chip_scores,
                   windows_df,
                   top_inds=None,
                   nrows=5,
                   ncols=5,
                   figsize=(12, 12),
                   w_pad=-2.5,
                   h_pad=-2.5):
    """Plot a grid of chips, by default the highest-scoring ones.

    Args:
        chip_scores: 1-D tensor with one score per chip/embedding.
        windows_df: forwarded to ``emb_idx_to_chip``, which does not use it.
        top_inds: explicit chip indices to show. If None, the top
            ``nrows * ncols`` chips by ``chip_scores`` are shown (or all chips,
            if there are fewer scores than panels).
        nrows, ncols: grid shape.
        figsize: matplotlib figure size in inches.
        w_pad, h_pad: ``tight_layout`` paddings; negative values shrink the gaps.
    """
    plt.close('all')
    # squeeze=False keeps `axs` a 2-D array even for 1x1 / 1xN / Nx1 grids, so
    # `.size` and `.flat` below always work.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    fig.tight_layout(w_pad=w_pad, h_pad=h_pad)
    if top_inds is None:
        # topk raises if k exceeds the number of scores; show fewer chips instead.
        k = min(axs.size, len(chip_scores))
        top_inds = torch.topk(chip_scores, k).indices
    for ax, i in zip(tqdm(axs.flat), top_inds):
        ax.imshow(emb_idx_to_chip(i, windows_df))
    # Hide axes on every panel, including any left empty.
    for ax in axs.flat:
        ax.axis('off')
    plt.show()

In [None]:
# Retrieve and display the 2x4 chips that best match a free-text query.
text_query = 'stadium'
chip_scores = get_chip_scores([text_query], embs)

show_top_chips(
    chip_scores,
    windows,
    nrows=2,
    ncols=4,
    figsize=(12, 6),
    # Negative paddings pack the grid tightly: -(12 / 4) and -(6 / 4).
    w_pad=-3.0,
    h_pad=-1.5,
)

# Zero-shot classification

In [None]:
def get_text_scores(text_queries, embs, T):
    """Zero-shot classify image embedding(s) against a list of text labels.

    Args:
        text_queries: class names / prompts, one per class (C of them).
        embs: (N, D) tensor of image embeddings, assumed L2-normalized. Any N >= 1
            is accepted; N == 1 behaves exactly as before.
        T: softmax temperature; must be > 0. Smaller values give peakier
            distributions.

    Returns:
        numpy array of per-class probabilities with size-1 dims squeezed:
        shape (C,) for a single embedding, (N, C) for several.

    Raises:
        ValueError: if ``embs`` is empty or ``T`` is not positive.
    """
    if len(embs) == 0:
        raise ValueError('embs must contain at least one embedding')
    if T <= 0:
        raise ValueError(f'temperature T must be positive, got {T}')
    tokens = tokenizer(text_queries)
    with torch.inference_mode():
        text_features = model.encode_text(tokens.to(DEVICE))
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # (N, D) @ (D, C) -> (N, C); cosine similarities when embs are normalized.
        logits = embs @ text_features.cpu().T
        # Softmax across classes, independently for each embedding.
        probs = (logits / T).softmax(dim=1)
    return probs.numpy().squeeze()

In [None]:
# Example chip for zero-shot classification. Indexing `embs` with a list keeps
# `img_emb` 2-D, shape (1, D).
CHIP_IDX = 3583
img, _ = ds[CHIP_IDX]
img_emb = embs[[CHIP_IDX]]

In [None]:
# Display the chip, converted from (C, H, W) to (H, W, C) for matplotlib.
fig, ax = plt.subplots()
ax.imshow(img.permute(1, 2, 0))
plt.show()

In [None]:
# Candidate labels for zero-shot classification; each string is tokenized and
# encoded as a text query by get_text_scores.
classes = [
    'forest',
    'harbor',
    'stadium',
    'parking lot',
]

In [None]:
# Class probabilities for the example chip (T=0.05 is the softmax temperature).
text_scores = get_text_scores(classes, img_emb, T=0.05)

fig, ax = plt.subplots()
ax.bar(classes, text_scores, ec='black')
ax.set(ylim=(0, 1), xlabel='Class', ylabel='Probability')
ax.yaxis.grid(linestyle='--', alpha=1)
plt.show()