In [1]:
#entire package imports
import os
import torch
import timm
import terratorch
from urllib.parse import urlparse
import matplotlib.pyplot as plt
import re
from typing import cast

#sub-imports
from terratorch.tasks import ClassificationTask, PixelwiseRegressionTask
import terratorch.models.backbones.prithvi_vit as prithvi_vit

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples,GeoDataset
from torchgeo.datasets.splits import random_bbox_assignment
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler,GeoSampler,RandomBatchGeoSampler,GridGeoSampler
from torchgeo.datamodules import GeoDataModule

from torch.utils.data import DataLoader,default_collate

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger


  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.22 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


First we specifify the model and create a task.

In [2]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]


model_args = {
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        "bands": pretrained_bands,
        # "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "backbone_pretrained_cfg_overlay":{"file": "C:Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 4,
        "necks":  VIT_UPERNET_NECK
}

task = ClassificationTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True)


In [10]:
if os.getcwd().endswith('UCProjectGroup1'): #from Tim: UCProjectGroup1 is the github repo and sometimes VSCode thinks my cwd is there
    os.chdir('..')
print(os.getcwd())

c:\Users\timvd\Documents\Uni_2024-2025\UC\Project\ProjectCode


Next, we define our dataset object based on torchgeo rasterdataset.

In [11]:
class Sentinel2(RasterDataset):
    filename_glob = '*.tif'
    #198_2019-01-31T10_06_36.654Z_1.tif
    # filename_regex = r'^.{6}_(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    # date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    all_bands = tuple([f'B0{i}' for i in range(1,14)])
    rgb_bands = ('B04', 'B03', 'B02')

    def plot(self, sample):
        # Find the correct band index order
        rgb_indices = []
        for band in self.rgb_bands:
            rgb_indices.append(self.all_bands.index(band))

        # Reorder and rescale the image
        image = sample['image'][rgb_indices].permute(1, 2, 0)
        image = torch.clamp(image / 10000, min=0, max=1).numpy()

        # Plot the image
        fig, ax = plt.subplots()
        ax.imshow(image)

        return fig
    
    def __getitem__(self,query):
        #dit is de source code van RasterDataSet.__getitem__() tot waar ik sample['test'] toevoeg
        #kan een work-around zijn?
        hits = self.index.intersection(tuple(query), objects=True)
        filepaths = cast(list[str], [hit.object for hit in hits])

        if not filepaths:
            raise IndexError(
                f'query: {query} not found in index with bounds: {self.bounds}'
            )

        if self.separate_files:
            data_list: list[torch.Tensor] = []
            filename_regex = re.compile(self.filename_regex, re.VERBOSE)
            for band in self.bands:
                band_filepaths = []
                for filepath in filepaths:
                    filename = os.path.basename(filepath)
                    directory = os.path.dirname(filepath)
                    match = re.match(filename_regex, filename)
                    if match:
                        if 'band' in match.groupdict():
                            start = match.start('band')
                            end = match.end('band')
                            filename = filename[:start] + band + filename[end:]
                    filepath = os.path.join(directory, filename)
                    band_filepaths.append(filepath)
                data_list.append(self._merge_files(band_filepaths, query))
            data = torch.cat(data_list)
        else:
            data = self._merge_files(filepaths, query, self.band_indexes)

        sample = {'crs': self.crs, 'bounds': query}

        data = data.to(self.dtype)
        if self.is_image:
            sample['image'] = data
        else:
            sample['mask'] = data

        #hier evt eigen dingen toevoegen aan de sample
        # sample['test'] = 1

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

#initialize dataset  
root = os.path.join(os.getcwd(), "data/images/images/training/120x120/positive")
dataset = Sentinel2(root)
torch.manual_seed(1)    
sampler = RandomGeoSampler(dataset,size=120,length=1) 
dataloader = DataLoader(dataset, sampler=sampler,collate_fn=stack_samples) 

In [12]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from torchgeo.samplers import GridGeoSampler
from torchgeo.datasets.splits import random_bbox_assignment

checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir='output', name='tutorial')

# You can also log directly to WandB
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(log_model="all") 

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=1, # train only one epoch for demo
    default_root_dir='output/test',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)

# for batch_idx, batch in enumerate(train_dl.keys()):
#     # print(batch)
#     print(batch)

class CustomGeoDataModule(GeoDataModule):  # defining a custom datamodule to feed it to the trainer
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)
        
        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator)  # not sure what this does yet BUT IT IS VERY NECESSARY
        
        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

custom_datamodule = CustomGeoDataModule(type(dataset), batch_size=1, patch_size=120, length=1)  # runtime error perhaps due to num_workers, could try 0 if commenting this out doesn't work (parallell resources)
# custom_datamodule = GeoDataModule(type(dataset), batch_size=1, patch_size=120, length=1, num_workers=6)  # previous module, doesn't work (gives "split" error)
_ = trainer.fit(model=task, train_dataloaders=custom_datamodule)

c:\Users\timvd\anaconda3\envs\UC-env-2\Lib\site-packages\lightning\pytorch\trainer\connectors\accelerator_connector.py:556: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


RuntimeError: shape '[-1, 6, 1, 120, 120]' is invalid for input of size 2433600