In [45]:
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights
from torchgeo.datasets import EuroSAT
from torchgeo.models.api import list_models, get_model_weights, get_model

import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from torchvision.transforms import Compose, Normalize
from torchvision.models.feature_extraction import create_feature_extractor

In [46]:
# List every registered model and the pretrained weights available for it.
# Instantiating a checkpoint fetches its weight file from its URL, so loading every
# checkpoint as a smoke test is opt-in; it also no longer leaks a `model` variable.
LOAD_ALL_WEIGHTS = False

for model_name in list_models():
    print(model_name)
    for weights in get_model_weights(model_name):
        print("\t", weights)
        if LOAD_ALL_WEIGHTS:
            get_model(model_name, weights=weights)

resnet18
	 ResNet18_Weights.SENTINEL2_ALL_MOCO
	 ResNet18_Weights.SENTINEL2_RGB_MOCO
	 ResNet18_Weights.SENTINEL2_RGB_SECO
resnet50
	 ResNet50_Weights.SENTINEL1_ALL_MOCO
	 ResNet50_Weights.SENTINEL2_ALL_MOCO
	 ResNet50_Weights.SENTINEL2_RGB_MOCO
	 ResNet50_Weights.SENTINEL2_ALL_DINO
	 ResNet50_Weights.SENTINEL2_RGB_SECO
vit_small_patch16_224
	 ViTSmall16_Weights.SENTINEL2_ALL_MOCO
	 ViTSmall16_Weights.SENTINEL2_ALL_DINO


In [4]:
# Prefer the second GPU (the originally targeted device); fall back to the first GPU and
# then to CPU so the notebook still runs on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# Randomly initialised 13-channel ResNet-50, kept as a baseline for the pretrained model.
# It gets its own name because `model` is reassigned to the pretrained network in the
# next cell; wrap it with create_feature_extractor as well to probe it.
random_init_model = get_model("resnet50", weights=None, in_chans=13).eval().to(device)

In [6]:
# ResNet-50 with the SENTINEL2_ALL_MOCO checkpoint (13 input channels, MoCo on SSL4EO-S12,
# per the weights metadata shown below), put in eval mode and moved to `device`.
model = get_model("resnet50", weights=ResNet50_Weights.SENTINEL2_ALL_MOCO)
model = model.eval().to(device)

In [49]:
# Show the checkpoint URL, its bundled eval transforms and its metadata.
ResNet50_Weights["SENTINEL2_ALL_MOCO"].value

Weights(url='https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/resolve/main/resnet50_sentinel2_all_moco.pth', transforms=AugmentationSequential(
  (augs): AugmentationSequential(
    (Resize_0): Resize(Resize(output_size=256, p=1.0, p_batch=1.0, same_on_batch=True, size=256, side=short, resample=bilinear, align_corners=True, antialias=False))
    (CenterCrop_1): CenterCrop(CenterCrop(p=1.0, p_batch=1.0, same_on_batch=True, resample=bilinear, cropping_mode=slice, align_corners=True, size=(224, 224), padding_mode=zeros))
  )
), meta={'dataset': 'SSL4EO-S12', 'in_chans': 13, 'model': 'resnet50', 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco'})

In [7]:
# Width of ResNet-50's globally pooled embedding (512 would be the ResNet-18 width).
num_features = 2048
# Wrap the backbone so a forward pass returns a dict keyed by "global_pool".
# NOTE: this rebinds `model`; re-running the cell without re-creating the backbone wraps it again.
model = create_feature_extractor(model, return_nodes=["global_pool"])

In [38]:
# Per-band mean / standard deviation for the 13 input bands (presumably raw EuroSAT pixel
# statistics in the dataset's band order — TODO confirm their provenance). An earlier
# experiment sub-selected indices [3, 2, 1] here, presumably the RGB bands.
band_means = torch.tensor([
    1354.40546513, 1118.24399958, 1042.92983953, 947.62620298, 1199.47283961,
    1999.79090914, 2369.22292565, 2296.82608323, 732.08340178, 12.11327804,
    1819.01027855, 1118.92391149, 2594.14080798,
])
band_stds = torch.tensor([
    245.71762908, 333.00778264, 395.09249139, 593.75055589, 566.4170017,
    861.18399006, 1086.63139075, 1117.98170791, 404.91978886, 4.77584468,
    1002.58768311, 761.30323499, 1231.58581042,
])

# Clipping window of mean ± 2 std, reshaped to (bands, 1, 1) so it broadcasts against an
# image assumed to be laid out as (bands, H, W).
min_value = (band_means - 2 * band_stds)[:, None, None]
max_value = (band_means + 2 * band_stds)[:, None, None]

# Per-band standardisation transform; not applied by the `preprocess` used for the datasets.
norm = Normalize(mean=band_means, std=band_stds)

def preprocess_minmax(sample):
    """Alternative scaling: map each band's mean ± 2 std window onto [0, 1] and clip.

    Reads the module-level ``min_value`` / ``max_value`` tensors. It is not the transform
    passed to the datasets; swap it in via ``transforms=`` to compare against ``preprocess``.
    """
    img = sample["image"].float()
    img = (img - min_value) / (max_value - min_value)
    sample["image"] = torch.clip(img, 0, 1)
    return sample


def preprocess(sample, scale=10000.0):
    """Cast ``sample["image"]`` to float and divide it by ``scale`` (in place).

    Args:
        sample: Dataset sample dict; only the ``"image"`` entry is modified.
        scale: Divisor applied to the raw pixel values. The default 10000 presumably maps
            Sentinel-2 digital numbers to reflectance — confirm against the pretrained
            weights' training preprocessing.

    Returns:
        The same ``sample`` dict, with the rescaled image.
    """
    sample["image"] = sample["image"].float() / scale
    return sample


# Every split is read from one on-disk copy of EuroSAT so train and test come from the same data.
EUROSAT_ROOT = "data/EuroSAT/"
BATCH_SIZE = 32
NUM_WORKERS = 6


def make_eurosat_loader(split):
    """Build the all-bands EuroSAT dataset for ``split`` plus an unshuffled DataLoader over it.

    Shuffling is unnecessary: features are extracted once and fed to a linear probe.
    Returns ``(dataset, loader)``.
    """
    dataset = EuroSAT(
        root=EUROSAT_ROOT,
        split=split,
        bands=EuroSAT.BAND_SETS["all"],
        transforms=preprocess,
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    return dataset, loader


train_ds, train_dl = make_eurosat_loader("train")
test_ds, test_dl = make_eurosat_loader("test")

In [39]:
def extract_features(model, dataloader, device, return_node="global_pool"):
    """Run ``model`` over every batch of ``dataloader`` and stack features and labels.

    Args:
        model: Callable mapping a batch of images to a dict of named outputs, e.g. a
            wrapper built with ``create_feature_extractor``.
        dataloader: Iterable of dict batches holding an ``"image"`` tensor and a
            ``"label"`` tensor.
        device: Device the images are moved to before the forward pass.
        return_node: Key of the model output used as the features.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated along the batch axis.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    feature_batches = []
    label_batches = []

    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        # No autograd bookkeeping is needed for frozen-feature extraction.
        with torch.inference_mode():
            features = model(images)[return_node].cpu().numpy()
        feature_batches.append(features)
        label_batches.append(batch["label"].numpy())

    if not feature_batches:
        raise ValueError("dataloader yielded no batches; nothing to extract")

    return np.concatenate(feature_batches, axis=0), np.concatenate(label_batches, axis=0)

In [40]:
# Frozen-backbone features and labels for the training split.
x_train, y_train = extract_features(model=model, dataloader=train_dl, device=device)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:06<00:00, 27.85it/s]


In [41]:
# Frozen-backbone features and labels for the held-out test split.
x_test, y_test = extract_features(model=model, dataloader=test_dl, device=device)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:05<00:00, 29.49it/s]


In [42]:
# Linear probe on the frozen features. C is the inverse L2 strength (C=50 -> weak
# regularisation); max_iter is raised above scikit-learn's default of 100.
linear_model = LogisticRegression(max_iter=1000, C=50.0)
linear_model.fit(X=x_train, y=y_train)

In [43]:
# Test-split accuracy of the linear probe (what LogisticRegression.score reports).
accuracy_score(y_test, linear_model.predict(x_test))

0.9829629629629629