Vision-language models
===

- SkyCLIP

In [None]:
# Install OpenCLIP, used below to build the SkyCLIP (ViT-L-14) model.
%pip install -q open_clip_torch

In [None]:
%%capture
# %%capture suppresses this cell's output so the values set below are not
# echoed into the notebook.
# NOTE(review): credentials are left blank here and must be filled in before
# running; the imagery bucket is presumably requester-pays, hence
# AWS_REQUEST_PAYER.
%env AWS_REQUEST_PAYER=requester
%env AWS_ACCESS_KEY_ID=
%env AWS_SECRET_ACCESS_KEY=
%env AWS_SESSION_TOKEN=

In [None]:
from os.path import join
from glob import glob
import gc

from rastervision.pipeline.file_system.utils import json_to_file
from rastervision.core.box import Box
from rastervision.core.data import RasterioSource, Scene
from rastervision.core.data.utils import geoms_to_geojson
from rastervision.pytorch_learner.dataset import (
    SemanticSegmentationSlidingWindowGeoDataset)

from tqdm.auto import tqdm
import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
import albumentations as A
from shapely.geometry import mapping
import pystac_client

# Use the GPU when CUDA is available; otherwise fall back to the CPU instead of
# failing on the first `.to(DEVICE)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

---

# Load SkyCLIP model

In [None]:
# Download the SkyCLIP ViT-L/14 checkpoint to ./SkyCLIP.pt (requires read
# access to this bucket).
!aws s3 cp s3://raster-vision-ahassan/qe/SkyCLIP_ViT_L14_top50pct/epoch_20.pt SkyCLIP.pt

In [None]:
import open_clip

# Build the ViT-L-14 architecture and its matching text tokenizer. No
# pretrained weights are requested here; the SkyCLIP checkpoint is loaded
# into `model` afterwards.
model_name = 'ViT-L-14'
tokenizer = open_clip.get_tokenizer(model_name)
model, _, preprocess = open_clip.create_model_and_transforms(model_name)

In [None]:
ckpt_path = 'SkyCLIP.pt'
# Load onto the CPU: the tensors are copied into the (still CPU-resident)
# model by load_state_dict anyway, so mapping them to the GPU first would only
# keep a redundant copy in GPU memory.
ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']
# Keys presumably carry a 'module.' prefix from a DataParallel/DDP-wrapped
# training model; strip it only where it is actually present instead of
# blindly dropping the first 7 characters of every key.
prefix = 'module.'
ckpt = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in ckpt.items()}
message = model.load_state_dict(ckpt)
# The weights now live in `model`; drop the raw checkpoint dict.
del ckpt
model = model.to(DEVICE).eval()

---

# Get imagery

In [None]:
# Box takes (ymin, xmin, ymax, xmax) with y = latitude and x = longitude.
# ymin must be smaller than ymax; the previous values were swapped, which
# produced an inverted box.
# NOTE(review): these coordinates (lon ~58.1-58.8, lat ~23.4-23.7) lie outside
# the US, while NAIP imagery only covers the contiguous US, so the NAIP search
# below will likely return no items. Confirm the intended area of interest.
bbox = Box(ymin=23.413, xmin=58.1, ymax=23.711, xmax=58.782)
# With correctly ordered bounds, to_shapely() already yields the axis-aligned
# rectangle, so no envelope computation is needed.
bbox_polygon = bbox.to_shapely()
search_geometry = mapping(bbox_polygon)

In [None]:
# Query Element 84's Earth Search STAC API for NAIP items that intersect the
# search geometry on a single acquisition date.
catalog = pystac_client.Client.open('https://earth-search.aws.element84.com/v1')

items = catalog.search(
    intersects=search_geometry,
    datetime='2019-09-18',
    collections=['naip'],
).item_collection()
# Number of matching items (displayed as the cell output).
len(items)

In [None]:
# Display the matching STAC items.
items

# Generate vector embeddings

In [None]:
# One image URI per STAC item found above; an empty list would make the
# raster source and dataset below fail.
# NOTE(review): 'image' is assumed to be the COG asset key in Earth Search's
# v1 `naip` collection. Confirm against `items[0].assets`.
img_uris = [item.assets['image'].href for item in items]

In [None]:
# Open all image URIs together as a single raster source, keeping only the
# first three bands (presumably R, G, B for NAIP), and display its shape.
rs = RasterioSource(uris=img_uris, channel_order=[0, 1, 2])
rs.shape

In [None]:
# SemanticSegmentationSlidingWindowGeoDataset is already imported in the
# first code cell.
# Chip the imagery into non-overlapping 400x400 windows (size == stride) and
# resize each chip to 224x224 for the model.
# NOTE(review): chips are not normalized with CLIP's mean/std here (the
# open_clip `preprocess` transform is unused). Confirm this matches how
# SkyCLIP expects its inputs.
ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    image_uri=img_uris,  # was the undefined name `image_uri`
    image_raster_source_kw=dict(channel_order=[0, 1, 2]),
    size=400,
    stride=400,
    out_size=224,
)

In [None]:
# Iterate the dataset in order (no shuffling) so each batch's embeddings can be
# written to the matching rows of `embs`.
dl = DataLoader(ds, batch_size=16, shuffle=False, num_workers=4)

In [None]:
# Length of the image embedding vectors (768 for ViT-L-14). Read it from the
# model instead of hard-coding it so that a different `model_name` cannot
# leave this out of sync.
EMBEDDING_DIM_SIZE = model.visual.output_dim

In [None]:
# Row k of `embs` holds the embedding of ds[k]: `dl` does not shuffle, so
# batches arrive in dataset order.
embs = torch.zeros(len(ds), EMBEDDING_DIM_SIZE)

with torch.inference_mode(), tqdm(dl, desc='Creating chip embeddings') as progress:
    start = 0
    for chips, _ in progress:
        stop = start + len(chips)
        embs[start:stop] = model.encode_image(chips.to(DEVICE)).cpu()
        start = stop

# Scale every row to unit length so dot products equal cosine similarities.
row_norms = embs.norm(dim=-1, keepdim=True)
embs /= row_norms

embs.shape

In [None]:
# Plain string: nothing is interpolated, so the f-prefix was unnecessary.
# NOTE(review): 'skysclip' looks like a typo for 'skyclip', and 'MA' does not
# obviously match the search bbox. The filename is kept unchanged in case
# something else refers to it.
embs_path = 'skysclip_naip_MA.pt'
torch.save(embs, embs_path)

In [None]:
# !aws s3 cp {embs_path} s3://...

# Text-to-image retrieval

In [None]:
def get_chip_scores(text_queries, embs):
    """Score every chip embedding against a single text query.

    Args:
        text_queries: A list holding exactly one query string.
        embs: Chip embeddings, one per row.

    Returns:
        A 1-D tensor of cosine similarities between the query and each row of
        ``embs``.
    """
    assert len(text_queries) == 1
    tokens = tokenizer(text_queries)
    with torch.inference_mode():
        query_emb = model.encode_text(tokens.to(DEVICE))
        query_emb /= query_emb.norm(dim=-1, keepdim=True)
        # The (1, D) query row broadcasts against every (N, D) chip row.
        return torch.cosine_similarity(query_emb.cpu(), embs)

def emb_idx_to_chip(i, windows, out_shape=(400, 400)):
    """Return the image chip for embedding row ``i`` in a form ``imshow`` accepts.

    Args:
        i: Row index into the embeddings (an int or 0-d tensor). Rows line up
            with dataset indices.
        windows: Unused; kept so existing callers keep working.
        out_shape: Unused; kept so existing callers keep working.

    Returns:
        The chip. When the dataset yields a tensor it is converted to a
        channels-last (H, W, C) NumPy array.
    """
    chip, _ = ds[int(i)]
    # The dataset presumably yields channels-first (C, H, W) tensors (its
    # default PyTorch output), which matplotlib's imshow rejects; move the
    # channel axis last.
    if isinstance(chip, torch.Tensor):
        chip = chip.permute(1, 2, 0).cpu().numpy()
    return chip

def show_top_chips(chip_scores, windows_df, top_inds=None, nrows=5, ncols=5, figsize=(12, 12), w_pad=-2.5, h_pad=-2.5):
    """Plot chips in an ``nrows`` x ``ncols`` grid, highest-scoring first.

    Args:
        chip_scores: 1-D tensor of per-chip scores (e.g. from get_chip_scores).
        windows_df: Forwarded to ``emb_idx_to_chip``.
        top_inds: Indices of the chips to show. Defaults to the highest-scoring
            chips, as many as fit in the grid (or as many scores as exist).
        nrows, ncols: Grid dimensions.
        figsize: Matplotlib figure size.
        w_pad, h_pad: Padding passed to ``fig.tight_layout``; negative values
            pull the chips closer together.
    """
    # Imported locally: `plt` is not imported in this notebook's import cell.
    import matplotlib.pyplot as plt

    plt.close('all')
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so `.size`
    # and `.flat` below always work.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    fig.tight_layout(w_pad=w_pad, h_pad=h_pad)
    if top_inds is None:
        # topk raises if k exceeds the number of scores.
        k = min(axs.size, len(chip_scores))
        top_inds = torch.topk(chip_scores, k).indices
    for ax, i in zip(tqdm(axs.flat), top_inds):
        chip = emb_idx_to_chip(i, windows_df)
        ax.imshow(chip)
    for ax in axs.flat:
        ax.axis('off')
    plt.show()

In [None]:
text_query = 'house with swimming pool'
chip_scores = get_chip_scores([text_query], embs)

# Show the 8 best matches in a 2x4 grid; negative padding tightens the grid.
show_top_chips(
    chip_scores,
    ds.windows,
    nrows=2,
    ncols=4,
    figsize=(12, 6),
    w_pad=-(12 / 4),
    h_pad=-(6 / 4),
)

# Zero-shot classification

In [None]:
def get_text_scores(text_queries, embs):
    """Score one image embedding against several text queries.

    Args:
        text_queries: Candidate class names or prompts.
        embs: Exactly one image embedding as a single row, e.g. ``embs[[idx]]``.

    Returns:
        A 1-D tensor with one cosine similarity per entry of ``text_queries``.
    """
    assert len(embs) == 1
    tokens = tokenizer(text_queries)
    with torch.inference_mode():
        class_embs = model.encode_text(tokens.to(DEVICE))
        class_embs /= class_embs.norm(dim=-1, keepdim=True)
        # The single image row broadcasts against every class row.
        return torch.cosine_similarity(embs, class_embs.cpu())

In [None]:
# Pick one chip (dataset index 123) to classify. Indexing `embs` with the list
# [123] keeps a leading dimension of size 1, i.e. a single-row batch, which is
# what get_text_scores asserts on.
img, _ = ds[123]
img_emb = embs[[123]]

In [None]:
# Candidate labels for zero-shot classification of the selected chip.
classes = ['warehouse', 'forest', 'harbor']

In [None]:
# get_text_scores asserts it receives exactly one embedding, so pass the
# selected chip's embedding `img_emb` rather than the whole `embs` matrix
# (which would fail that assertion).
text_scores = get_text_scores(classes, img_emb)
text_scores