In [1]:
# Demo using the torchgeo package to initialize a SatlasPretrain model and finetune
# on the UCMerced dataset.
#
# SETUP - this demo requires a DIFFERENT conda environment than the SatlasPretrain demo:
# conda create --name torchgeodemo python=3.12
# conda activate torchgeodemo
# NOTE: the Satlas weights ship with torchgeo 0.6.0; until that release is out (the latest
# release is 0.5.1), install torchgeo from git:
# pip install git+https://github.com/microsoft/torchgeo

In [3]:
import os
import torch
import tempfile
from typing import Optional
from lightning.pytorch import Trainer

from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchgeo.datamodules import UCMercedDataModule
from torchgeo.trainers import ClassificationTask

In [4]:
# Experiment Arguments
batch_size = 8  # images per optimisation step
num_workers = 2  # DataLoader worker processes
max_epochs = 10  # upper bound on training epochs
fast_dev_run = False  # True: Lightning runs a single batch as a quick smoke test
seed = 0  # global RNG seed

# Fine-tuning is stochastic (fresh classifier head, data shuffling, augmentations),
# so seed torch up front to make a Restart & Run All reproducible.
torch.manual_seed(seed)

In [5]:
# Lightning datamodule from torchgeo for UCMerced. Data lives under the system temp dir,
# and download=True lets the dataset fetch it there if it is missing.
root = os.path.join(tempfile.gettempdir(), "ucm")
datamodule = UCMercedDataModule(
    download=True,
    root=root,
    num_workers=num_workers,
    batch_size=batch_size,
)

In [6]:
# Custom ClassificationTask to load in the SatlasPretrain model
class SatlasClassificationTask(ClassificationTask):
    """ClassificationTask whose model is a SatlasPretrain Swin-V2-B backbone.

    Reads ``num_classes`` and, when present, ``in_channels`` (default 3) and
    ``freeze_backbone`` (default False) from ``self.hparams``. The hparams are expected
    to be saved by ``ClassificationTask.__init__``.
    """

    def configure_models(self) -> None:
        """Build ``self.model``: pretrained Swin-V2-B with a task-specific stem and head."""
        weights = Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS
        self.model = swin_v2_b(weights)

        # Replace the first (patch-embedding) conv only if the task's input channel count
        # differs from the pretrained one. Replacing it unconditionally would discard its
        # pretrained weights even for plain RGB input.
        in_channels = self.hparams.get("in_channels", 3)
        first_layer = self.model.features[0][0]
        if in_channels != first_layer.in_channels:
            self.model.features[0][0] = torch.nn.Conv2d(
                in_channels,
                first_layer.out_channels,
                kernel_size=first_layer.kernel_size,
                stride=first_layer.stride,
                padding=first_layer.padding,
                bias=(first_layer.bias is not None),
            )

        # Replace the classifier head with one sized for this task's classes. Take the input
        # width from the existing head rather than hard-coding the Swin-V2-B feature size.
        self.model.head = torch.nn.Linear(
            in_features=self.model.head.in_features,
            out_features=self.hparams["num_classes"],
            bias=True,
        )

        # Overriding configure_models bypasses the base task's backbone freezing, so honour
        # freeze_backbone here: train only the new head when it is requested.
        if self.hparams.get("freeze_backbone", False):
            for name, param in self.model.named_parameters():
                param.requires_grad = name.startswith("head.")

In [7]:
# Initialize the Classification Task. UCMerced Land Use is RGB imagery split into
# 21 scene classes; in_channels=3 is spelled out for clarity (it is the base default).
task = SatlasClassificationTask(in_channels=3, num_classes=21)

Downloading: "https://huggingface.co/torchgeo/swin_v2_b_sentinel2_rgb_satlas/resolve/main/swin_v2_b_sentinel2_rgb_satlas-51471041.pth" to /Users/piperw/.cache/torch/hub/checkpoints/swin_v2_b_sentinel2_rgb_satlas-51471041.pth
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 336M/336M [00:14<00:00, 24.6MB/s]


In [8]:
# Set up the Lightning Trainer. Logs (and default checkpoints) go under a temp dir,
# and a CUDA GPU is used when one is present, otherwise the CPU.
default_root_dir = os.path.join(tempfile.gettempdir(), "experiments")
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    min_epochs=1,
    max_epochs=max_epochs,
    log_every_n_steps=1,  # log every step; the epochs here are short
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Fine-tune the task; the Trainer takes its train and val loaders from the datamodule.
trainer.fit(task, datamodule=datamodule)

Missing logger folder: /var/folders/5f/p350t6y12v73c_wcsm2csn7h0000gp/T/experiments/lightning_logs


Downloading https://cdn-lfs.huggingface.co/repos/fe/cf/fecf1e16b43b78d97916708d99b930db9e863bbfb0a0ea4c5ea038cd1390fa4a/06c539ef28703a58fb07bd2837991ac7c48b813b00bb12ac197efd813a18daeb?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27UCMerced_LandUse.zip%3B+filename%3D%22UCMerced_LandUse.zip%22%3B&response-content-type=application%2Fzip&Expires=1708385017&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODM4NTAxN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mZS9jZi9mZWNmMWUxNmI0M2I3OGQ5NzkxNjcwOGQ5OWI5MzBkYjllODYzYmJmYjBhMGVhNGM1ZWEwMzhjZDEzOTBmYTRhLzA2YzUzOWVmMjg3MDNhNThmYjA3YmQyODM3OTkxYWM3YzQ4YjgxM2IwMGJiMTJhYzE5N2VmZDgxM2ExOGRhZWI%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=CJ64TPAJyJrQFk8QW-s0D0r2InD%7EOBa0HH8aJspgnfbgeaw6dZ4eLpO-54qUPmoPRcS2IuzhJEjOotJ9fkuRgMh3tgR2Ms4IeB46OUF3gH1HLeIyRHQsI3Sh1pvoDrPFr4i3IIL8w5FF1ikr1M9XYpJXBU7fWUaF1HS-PWbQPeJTyylm7O8

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 332468434/332468434 [00:07<00:00, 43630613.74it/s]


Downloading https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt to /var/folders/5f/p350t6y12v73c_wcsm2csn7h0000gp/T/ucm/uc_merced-train.txt


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21736/21736 [00:00<00:00, 1274694.03it/s]

Downloading https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt to /var/folders/5f/p350t6y12v73c_wcsm2csn7h0000gp/T/ucm/uc_merced-val.txt



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7201/7201 [00:00<00:00, 4159645.10it/s]


Downloading https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt to /var/folders/5f/p350t6y12v73c_wcsm2csn7h0000gp/T/ucm/uc_merced-test.txt


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7260/7260 [00:00<00:00, 5803439.50it/s]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | SwinTransformer  | 86.9 M
---------------------------------------------------
86.9 M    Trainable params
0         Non-trainable params
86.9 M    Total params
347.709   Total estimated model params size (MB)


Sanity Checking: |                                                                                                                                                   | 0/? [00:00<?, ?it/s]

/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                                                                                                                                           

/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:   0%|                                                                                                                                                     | 0/158 [00:00<?, ?it/s]

/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
