In [2]:
import timm
import os
import torch
import pandas as pd
import numpy as np
import configparser
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read the project config; DATA_DIR is the root that holds the DHS CSV and image folders.
CONFIG_PATH = '../config.ini'
config = configparser.ConfigParser()
# ConfigParser.read silently skips files it cannot open and returns the list of
# files it did parse. Check it so a wrong path fails here with a clear message
# instead of as a KeyError on config['PATHS'] below.
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Config file not found: {os.path.abspath(CONFIG_PATH)}")

DATA_DIR = config['PATHS']['DATA_DIR']

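Initialize the model: download the pretrained SSL4EO-L ResNet-50 weights (Landsat ETM+ surface reflectance, MoCo) from the Hugging Face Hub. Then load them into a 6-channel timm ResNet-50 whose classification head is removed, so it outputs embeddings. For more info, see https://huggingface.co/torchgeo/ssl4eo_landsat/tree/main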

In [4]:
repo_id = "torchgeo/ssl4eo_landsat"
filename = "resnet50_landsat_etm_sr_moco-1266cde3.pth"

# Download the model weights (hf_hub_download caches the file locally).
model_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Create model.
# - map_location="cpu" lets a checkpoint saved from GPU tensors load on a
#   CPU-only machine; the model is moved to the compute device separately.
# - weights_only=True refuses to unpickle arbitrary Python objects from the
#   downloaded file.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
# in_chans=6: six input bands. num_classes=0: no classifier head, so the model
# outputs feature vectors (the embeddings) instead of class logits.
model = timm.create_model("resnet50", in_chans=6, num_classes=0)
model.load_state_dict(state_dict)

<All keys matched successfully>

Define a dataset that pairs each DHS cluster's Landsat image with its `iwi` target, and wrap it in a DataLoader.

In [5]:
class RegressionDataset(Dataset):
    """Pair each CSV row with its Landsat image array and a scaled regression target.

    Each row of ``csv_file`` must provide a ``cluster_id`` column and an ``iwi``
    column. The image for a row is read with ``np.load`` from
    ``<img_dir>/<cluster_id>/landsat.np``.

    Args:
        csv_file: Path to the CSV listing the samples, one per row.
        img_dir: Directory containing one sub-directory per cluster ID.
        transform: Optional callable applied to the loaded numpy array.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for row ``idx``.

        ``image`` is the loaded array, passed through ``transform`` if one was
        given. ``target`` is the row's ``iwi`` value divided by 100.
        """
        # Column-wise access keeps each column's own dtype. A row Series
        # (df.iloc[idx]) would upcast int IDs to float when every column is numeric.
        cluster_id = self.df['cluster_id'].iloc[idx]
        # str(): IDs that pandas parsed as numbers would otherwise make
        # os.path.join raise a TypeError.
        img_path = os.path.join(self.img_dir, str(cluster_id), 'landsat.np')
        img = np.load(img_path)
        target = self.df['iwi'].iloc[idx] / 100
        if self.transform is not None:
            img = self.transform(img)
        return img, target

In [6]:
# Preprocessing for the Landsat arrays.
# - SR_SCALE and SR_OFFSET convert stored values to surface reflectance. They
#   match the USGS Landsat Collection 2 Level-2 SR scaling; confirm against the
#   product guide.
# - Reflectance is then clipped to [0, SR_MAX] and divided by SR_MAX, so every
#   band ends up in [0, 1].
SR_SCALE = 0.0000275
SR_OFFSET = -0.2
SR_MAX = 0.3

landsat_transform = transforms.Compose([
    transforms.ToTensor(),  # numpy (H, W, C) array -> (C, H, W) tensor
    transforms.Lambda(lambda x: x * SR_SCALE + SR_OFFSET),
    transforms.Lambda(lambda x: torch.clamp(x, 0.0, SR_MAX)),
    transforms.Lambda(lambda x: x / SR_MAX),
])

csv_file = os.path.join(DATA_DIR, 'dhs_with_imgs.csv')
img_dir = os.path.join(DATA_DIR, 'dhs_images')

BATCH_SIZE = 128
# The timings recorded at the end of this notebook show no gain beyond
# ~8-12 workers. Never start more workers than there are CPU cores.
NUM_WORKERS = min(12, os.cpu_count() or 1)

dataset = RegressionDataset(csv_file=csv_file, img_dir=img_dir, transform=landsat_transform)
# shuffle=False keeps batches in CSV row order, so embedding row i matches CSV row i.
# Pinned host memory only helps (and only avoids a warning) when copying to a CUDA device.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

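Run inference with the SSL4EO model and save the resulting embeddings as a `.npy` file. The DataLoader does not shuffle, so the embedding rows follow the row order of `dhs_with_imgs.csv`. This takes about 5 to 10 minutes on my machine.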

In [7]:
# Evaluation mode: BatchNorm uses its stored statistics instead of batch statistics.
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

embedding_batches = []

with torch.no_grad():
    # Targets are not needed to compute embeddings.
    for images, _ in tqdm(dataloader, desc="Running Inference"):
        # non_blocking lets the host->device copy overlap with compute when the
        # DataLoader pins memory. The .cpu() call below synchronises before results are read.
        images = images.to(device, dtype=torch.float32, non_blocking=True)
        embedding_batches.append(model(images).cpu().numpy())

# Concatenate all batches into one (num_samples, num_features) numpy array.
embeddings = np.concatenate(embedding_batches, axis=0)
assert len(embeddings) == len(dataset), "expected exactly one embedding per CSV row"

# Persist the embeddings. Row i corresponds to row i of dhs_with_imgs.csv
# because the DataLoader does not shuffle.
EMBEDDINGS_PATH = os.path.join(DATA_DIR, 'ssl4eo_embeddings.npy')
np.save(EMBEDDINGS_PATH, embeddings)

Running Inference: 100%|██████████| 537/537 [04:25<00:00,  2.02it/s]


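Inference wall-clock time for different DataLoader `num_workers` settings (batch size 128):

- `num_workers=4`: 05:37
- `num_workers=8`: 04:22
- `num_workers=12`: 04:25
- `num_workers=16`: 04:19

Beyond 8 workers the run time barely changes. This suggests that data loading stops being the bottleneck around that point.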