Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Pretrained Weights

In this tutorial, we demonstrate some of the pretrained weights available in TorchGeo. The implementation follows torchvision's [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial, specifically a subset containing only 100 images.

It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the "Open in Colab" button above to get started.

## Setup

First, we install TorchGeo.

In [1]:
%pip install torchgeo

Collecting torchgeo
  Downloading torchgeo-0.6.1-py3-none-any.whl.metadata (19 kB)
Collecting fiona>=1.8.21 (from torchgeo)
  Downloading fiona-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
Collecting kornia>=0.7.3 (from torchgeo)
  Downloading kornia-0.7.4-py2.py3-none-any.whl.metadata (18 kB)
Collecting lightly!=1.4.26,>=1.4.5 (from torchgeo)
  Downloading lightly-1.5.14-py3-none-any.whl.metadata (36 kB)
Collecting lightning!=2.3.*,>=2 (from lightning[pytorch-extra]!=2.3.*,>=2->torchgeo)
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting rasterio<1.4,>=1.3 (from torchgeo)
  Downloading rasterio-1.3.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting rtree>=1 (from torchgeo)
  Downloading Rtree-1.3.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata

## Imports

Next, we import TorchGeo and any other libraries we need.

In [1]:
%matplotlib inline

import os
import tempfile

import timm
import torch
from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask

The following variables can be used to control training.

In [2]:
batch_size = 10
num_workers = 2
max_epochs = 10
fast_dev_run = False
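
As a rough guide to what these settings imply, the number of optimizer steps follows directly from the batch size and the epoch count. The sketch below assumes a training split of 60 images, which is a hypothetical figure chosen for illustration (the actual split size may differ), and the helper name `steps_per_epoch` is not part of TorchGeo or Lightning.

```python
import math

batch_size = 10
max_epochs = 10
num_train_images = 60  # assumption for illustration, not read from the dataset


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


per_epoch = steps_per_epoch(num_train_images, batch_size)
total_steps = per_epoch * max_epochs
print(per_epoch, total_steps)
```

With so few steps per epoch, a run of this size finishes quickly, which is why the tutorial can afford to log every step.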

## Datamodule

We will utilize TorchGeo's [Lightning](https://lightning.ai/docs/pytorch/stable/) datamodules to organize the dataloader setup.

In [3]:
root = os.path.join(tempfile.gettempdir(), 'eurosat100')
datamodule = EuroSAT100DataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)

## Weights

Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB input, others have been pretrained on Sentinel-2 imagery with 13 input channels and can therefore be useful for transfer learning tasks involving Sentinel-2 data.

To access these weights you can do the following:

In [4]:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO

This set of weights is a torchvision `WeightsEnum` and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing models with them. Since EuroSAT is a classification dataset, we can use a `ClassificationTask` object, which holds the model and optimizer as well as the training logic.

In [5]:
task = ClassificationTask(
    model='resnet18',
    loss='ce',
    weights=weights,
    in_channels=13,
    num_classes=10,
    lr=0.001,
    patience=5,
)

Downloading: "https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth" to /root/.cache/torch/hub/checkpoints/resnet18_sentinel2_all_moco-59bfdff9.pth
100%|██████████| 42.8M/42.8M [00:00<00:00, 156MB/s]


If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/huggingface/pytorch-image-models) model with pretrained weights from TorchGeo as follows:

In [6]:
in_chans = weights.meta['in_chans']
model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
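
The `_IncompatibleKeys` result above is expected: the checkpoint contains no entries for the classification head (`fc`), so those parameters keep their random initialization while the rest of the network is loaded. The behaviour of `strict=False` can be reproduced with plain PyTorch. The sketch below uses a hypothetical `TinyNet` module and a hand-built checkpoint rather than the real ResNet-18 weights.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone plus classification head."""

    def __init__(self, in_chans: int = 13, num_classes: int = 10) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_chans, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x).mean(dim=(2, 3))  # global average pooling
        return self.fc(x)


# Pretend checkpoint: backbone weights only, no classification head
pretrained = TinyNet()
checkpoint = {
    k: v.clone() for k, v in pretrained.state_dict().items() if not k.startswith('fc.')
}

model = TinyNet()
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # head parameters absent from the checkpoint
print(result.unexpected_keys)  # checkpoint entries the model does not have

# The backbone now matches the checkpoint; the head is still randomly initialized
backbone_loaded = torch.equal(model.conv.weight, pretrained.conv.weight)
logits = model(torch.randn(2, 13, 64, 64))
print(backbone_loaded, logits.shape)
```

With `strict=True`, which is the default, the same call would raise a `RuntimeError` because of the missing head parameters.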

## Training

To train our pretrained model on the EuroSAT dataset we will make use of Lightning's [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html). For a more elaborate explanation of how TorchGeo uses Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html).

In [7]:
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')

In [8]:
trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    min_epochs=1,
    max_epochs=max_epochs,
)

INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs


In [9]:
trainer.fit(model=task, datamodule=datamodule)

Downloading https://cdn-lfs.hf.co/repos/fc/1d/fc1dee780dee1dae2ad48856d0961ac6aa5dfcaaaa4fb3561be4aedf19b7ccc7/2ed4bb4a6808004c98691f64b366827f7783c76a49151e7c2b70423eb77a5b76?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27EuroSAT100.zip%3B+filename%3D%22EuroSAT100.zip%22%3B&response-content-type=application%2Fzip&Expires=1732548716&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMjU0ODcxNn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mYy8xZC9mYzFkZWU3ODBkZWUxZGFlMmFkNDg4NTZkMDk2MWFjNmFhNWRmY2FhYWE0ZmIzNTYxYmU0YWVkZjE5YjdjY2M3LzJlZDRiYjRhNjgwODAwNGM5ODY5MWY2NGIzNjY4MjdmNzc4M2M3NmE0OTE1MWU3YzJiNzA0MjNlYjc3YTViNzY%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=Qu-cmrx5hJLfv7zuUALavHMfuP8IYQAFd93Sv5ttZeg0U0qE6pofu5Q4y%7EJKUaSyvHs%7Ev3QpQyLXOLr9GQKIyjvLINFarE0hbrWBBERqklQRMNOX77vmJcNT2ICTj86KOHSqYfv%7EYegS-TWm7UgTr6V1DCmZu1R%7Eyoj0BYZlnhjOezCiI5hvetsDDMSJ09-Ic3WmbH2fkeGbPLujfY

100%|██████████| 7.72M/7.72M [00:00<00:00, 98.0MB/s]


Downloading https://huggingface.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt to /tmp/eurosat100/eurosat-train.txt


100%|██████████| 1.15k/1.15k [00:00<00:00, 1.40MB/s]


Downloading https://huggingface.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt to /tmp/eurosat100/eurosat-val.txt


100%|██████████| 383/383 [00:00<00:00, 457kB/s]


Downloading https://huggingface.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt to /tmp/eurosat100/eurosat-test.txt


100%|██████████| 382/382 [00:00<00:00, 336kB/s]
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | ResNet           | 11.2 M | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.852    Total estimated model params size (MB)
110       Modules in train mode
0         Modules in eval mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
