In [1]:
# clay model repo on branch clay-v0.2-run
# https://github.com/Clay-foundation/model/tree/clay-v0.2-run
CLAY_MODEL_SRC = "../../../model"
CKPT_PATH = "/opt/data/models/clay-model-v0.2-last.ckpt"
import sys; sys.path.append(CLAY_MODEL_SRC)
from src.model_clay import CLAYModule

# check we are in correct branch

clay model repo on branch `clay-v0.2-run` https://github.com/Clay-foundation/model/tree/clay-v0.2-run


In [2]:
pwd = !pwd
pwd=pwd[0]
%cd $CLAY_MODEL_SRC
clayrepo_branch = !git rev-parse --abbrev-ref HEAD
clayrepo_branch = clayrepo_branch[0]
%cd $pwd
clay_repo_branch = 'clay-v0.2-run'
if clayrepo_branch != clay_repo_branch:
    raise ValueError(f"must switch to branch {clay_repo_branch} on clay model repo")

/home/ubuntu/model
/home/ubuntu/earth-text/notebooks/models


# load model

In [3]:
import torch
from einops import rearrange, reduce, repeat


In [4]:
CKPT_PATH = "/opt/data/models/clay-model-v0.2-last.ckpt"

m = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
    #band_groups={"rgb": (2, 1, 0)},
    #band_groups={"rgb": (2, 1, 0), "nir": (3,)},
    band_groups={"rgb": (2, 1, 0), "nir": (3,), "sar": (4,5)},
    #band_groups={"rgb": (2, 1, 0), "nir": (3,), "sar": (4,5), 'rededge': (6,7,8,9)},
    bands=4,
    strict=False,  # ignore the extra parameters in the checkpoint
    embeddings_level="mean",
)

/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:188: Found keys that are not in the model state dict but in the checkpoint: ['model.encoder.patch_embedding.rededge.proj.weight', 'model.encoder.patch_embedding.rededge.proj.bias', 'model.encoder.patch_embedding.rededge.norm.weight', 'model.encoder.patch_embedding.rededge.norm.bias', 'model.encoder.patch_embedding.swir.proj.weight', 'model.encoder.patch_embedding.swir.proj.bias', 'model.encoder.patch_embedding.swir.norm.weight', 'model.encoder.patch_embedding.swir.norm.bias', 'model.encoder.patch_embedding.dem.proj.weight', 'model.encoder.patch_embedding.dem.proj.bias', 'model.encoder.patch_embedding.dem.norm.weight', 'model.encoder.patch_embedding.dem.norm.bias', 'model.decoder.embed_to_pixels.rededge.weight', 'model.decoder.embed_to_pixels.rededge.bias', 'model.decoder.embed_to_pixels.swir.weight', 'model.decoder.embed_to_pixels.swir.bias', 'model.decoder.embed_to_pixels.dem.weight', 'model.decoder.embed_to_pixe

# loop over all files batching them and predict

In [34]:
import xarray as xr
import os
import numpy as np
from progressbar import progressbar as pbar
from time import time
import pickle

basedir = "/opt/data/clay-california-worldcover-rgbnir-vvvh-chips/chips"
embeddings_dir = "/opt/data/clay-california-worldcover-rgbnir-vvvh-chips/embeddings_v0.2"
patch_embeddings_dir = "/opt/data/clay-california-worldcover-rgbnir-vvvh-chips/patch_embeddings_v0.2"
files = os.listdir(basedir)

batch_size = 2
batch = []

t0 = time()

for batchnb in pbar(range(0,len(files),batch_size)):

    # load btch of images
    batchfiles = files[batchnb:batchnb+batch_size]
    batch = []
    for fname in batchfiles:
        
        with xr.open_dataarray(f"{basedir}/{fname}") as z:
            img = z.data.copy()
        batch.append(img)

    # prepare data structure for model
    z = { 'pixels': torch.tensor(np.r_[batch]).cuda(),
          'timestep': torch.tensor([[0., 0., 0.]] * batch_size).cuda(),
          'latlon': torch.tensor([[0.,0.]] * batch_size).cuda()
        }

    # run model
    embeddings_raw, _, _, _ =  m.model.encoder(z)

    # compute patch and image embeddings
    patch_embeddings_per_group = rearrange(
        embeddings_raw[:, :-2, :], "b (g h w) d -> b g h w d", w=16, h=16, g=len(m.model.band_groups)
    )
    patch_embeddings = reduce(
        patch_embeddings_per_group, "b g h w d -> b h w d", "mean"
    )
    image_embeddings = reduce(
        patch_embeddings, "b h w d -> b d", "mean"
    )

    # save embeddings
    for i,fname in enumerate(batchfiles):
        dest_fname = fname.split(".")[0]+".pkl"
        #print (batchnb, i, dest_fname)
        with open(f"{patch_embeddings_dir}/{dest_fname}", "wb") as f:
            pickle.dump(patch_embeddings[i].cpu().detach().numpy(), f)
            
        with open(f"{embeddings_dir}/{dest_fname}", "wb") as f:
            pickle.dump(image_embeddings[i].cpu().detach().numpy(), f)

    # empty cuda memory
    del image_embeddings
    del patch_embeddings
    del patch_embeddings_per_group
    del z
    torch.cuda.empty_cache()
    #if batchnb>10:
    #    break
t1 = time()
print (t1-t0)

[38;2;0;255;0m100%[39m [38;2;0;255;0m(55850 of 55850)[39m |##################| Elapsed Time: 1:00:15 Time:  1:00:152955


3615.498753786087


In [14]:
# random generation of data
nbands = len([i for j in m.model.band_groups.values() for i in j])
z = { 'pixels': torch.rand((batch_size,nbands,256,256)),
      'timestep': torch.tensor([[0., 0., 0.]] * batch_size),
      'latlon': torch.tensor([[-124.,43.]] * batch_size)
    }
