In [1]:
import os
import torch
import timm

import terratorch
from terratorch.tasks import ClassificationTask, PixelwiseRegressionTask, SemanticSegmentationTask

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, GeoDataset, UnionDataset
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler,GeoSampler,RandomBatchGeoSampler

import terratorch.models.backbones.prithvi_vit as prithvi_vit

from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from dataset_original import create_dataset
import pandas as pd
from torch.utils.data import DataLoader, ConcatDataset, RandomSampler

from terratorch.datamodules import GenericNonGeoClassificationDataModule

import tqdm
import rasterio as rio
import numpy as np
import pandas as pd

from rasterio.enums import Resampling

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Example of how to create a model using timm 
model = timm.create_model(
    "prithvi_vit_100", pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"), pretrained=True, num_classes=0, in_chans=6,
)

print(model)

INFO:root:Model bands not passed. Assuming bands are ordered in the same way as [<HLSBands.BLUE: 'BLUE'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.RED: 'RED'>, <HLSBands.NIR_NARROW: 'NIR_NARROW'>, <HLSBands.SWIR_1: 'SWIR_1'>, <HLSBands.SWIR_2: 'SWIR_2'>].            Pretrained patch_embed layer may be misaligned with current bands
INFO:timm.models._builder:Loading pretrained weights from file (Prithvi_EO_V1_100M.pt)
INFO:timm.models._helpers:Loaded  from checkpoint 'Prithvi_EO_V1_100M.pt'


TemporalViTEncoder(
  (patch_embed): PatchEmbed(
    (proj): Conv3d(6, 768, kernel_size=(1, 16, 16), stride=(1, 16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (d

In [3]:
print(timm.list_pretrained())
print(help(terratorch.tasks.ClassificationTask))

['bat_resnext26ts.ch_in1k', 'beit_base_patch16_224.in22k_ft_in22k', 'beit_base_patch16_224.in22k_ft_in22k_in1k', 'beit_base_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_224.in22k_ft_in22k', 'beit_large_patch16_224.in22k_ft_in22k_in1k', 'beit_large_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_512.in22k_ft_in22k_in1k', 'beitv2_base_patch16_224.in1k_ft_in1k', 'beitv2_base_patch16_224.in1k_ft_in22k', 'beitv2_base_patch16_224.in1k_ft_in22k_in1k', 'beitv2_large_patch16_224.in1k_ft_in1k', 'beitv2_large_patch16_224.in1k_ft_in22k', 'beitv2_large_patch16_224.in1k_ft_in22k_in1k', 'botnet26t_256.c1_in1k', 'caformer_b36.sail_in1k', 'caformer_b36.sail_in1k_384', 'caformer_b36.sail_in22k', 'caformer_b36.sail_in22k_ft_in1k', 'caformer_b36.sail_in22k_ft_in1k_384', 'caformer_m36.sail_in1k', 'caformer_m36.sail_in1k_384', 'caformer_m36.sail_in22k', 'caformer_m36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k_384', 'caformer_s18.sail_in1k', 'caformer_s18.sail_in1k_384', 'caformer_s

## Classification - Finetune Prithvi to act as a classification model

In [4]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]


model_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
        # "bands": ('B02', 'B03', 'B04','B08','B011', 'B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 4,
        "necks":  VIT_UPERNET_NECK
}

task = ClassificationTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True,
    class_names=["Fossil Brown coal", "Fossil Gas", "Fossil Hard coal", "Fossil Peat"]
)

# bins (classification)
# sloop activation function eruit
# investigate custom head??? Baseclass?

## Training model

Ignore cel below (adapted from original paper)

In [7]:
path = os.getcwd()  # current path
datadir = os.path.join(path, "data\images\images")
reg_file = os.path.join(path, "data\labels.csv")
seglabeldir = os.path.join(path, "data\segmentation_labels\segmentation_labels")

reg_data = pd.read_csv(reg_file)
# create dataset
data_train_120x120 = create_dataset(datadir=os.path.join(datadir, 'training/120x120/'),
                                    seglabeldir=os.path.join(seglabeldir, 'training/120x120/'),
                                    reg_data=reg_data, mult=4, train=True, channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
data_train_300x300 = create_dataset(datadir=os.path.join(datadir, 'training/300x300/'),
                                    seglabeldir=os.path.join(seglabeldir, 'training/300x300/'),
                                    reg_data=reg_data, mult=4, train=True, channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], size=300)

# data_train = ConcatDataset([data_train_120x120, data_train_300x300])
data_train = data_train_300x300

train_sampler = RandomSampler(data_train, replacement=True, num_samples=int(2 * len(data_train) / 3))

# initialize data loaders
train_dl = DataLoader(data_train, batch_size=32, num_workers=6,
                        pin_memory=True, sampler=train_sampler)

print(data_train_300x300.imgfiles)  # printing information from MutiTaskDataset class


['c:\\Users\\alhst\\Documents\\AI Master\\Urban Computing\\Project\\Prithvi\\Files\\data\\images\\images\\training/300x300/positive\\0000__S2A-MSIL2A-ST20200104T110726-N0213-R094-T30UYV-20200104T122020.tif'
 'c:\\Users\\alhst\\Documents\\AI Master\\Urban Computing\\Project\\Prithvi\\Files\\data\\images\\images\\training/300x300/positive\\0000__S2A-MSIL2A-ST20200206T111719-N0214-R137-T30UYV-20200206T122704.tif'
 'c:\\Users\\alhst\\Documents\\AI Master\\Urban Computing\\Project\\Prithvi\\Files\\data\\images\\images\\training/300x300/positive\\0000__S2A-MSIL2A-ST20200317T111723-N0214-R137-T30UYV-20200317T123526.tif'
 ...
 'c:\\Users\\alhst\\Documents\\AI Master\\Urban Computing\\Project\\Prithvi\\Files\\data\\images\\images\\training/300x300/positive\\0298__S2B-MSIL2A-ST20210408T101634-N0300-R022-T33UVS-20210408T132617.tif'
 'c:\\Users\\alhst\\Documents\\AI Master\\Urban Computing\\Project\\Prithvi\\Files\\data\\images\\images\\training/300x300/positive\\0298__S2B-MSIL2A-ST20210428T101632

Part 1: Constructing custom RasterDataset to transform our data to useable format

Ignore cel below (Sentinel class that works with black-box datamodule, previous implementation). Yields batch size mismatch error.

In [57]:
from urllib.parse import urlparse

import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader,default_collate

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples,GeoDataset
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler,GeoSampler,RandomBatchGeoSampler
from torchgeo.datamodules import GeoDataModule

import re
from typing import cast
import pandas as pd
 
%matplotlib inline
plt.rcParams['figure.figsize'] = (6, 6)

class Sentinel2(RasterDataset):
    filename_glob = '*.tif'
    #198_2019-01-31T10_06_36.654Z_1.tif
    # filename_regex = r'^.{6}_(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    # date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    all_bands = tuple([f'B0{i}' for i in range(1,14)])
    rgb_bands = ('B04', 'B03', 'B02')
    # bands = ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013')
    bands = ('B01')

    def plot(self, sample):
        # Find the correct band index order
        rgb_indices = []
        for band in self.rgb_bands:
            rgb_indices.append(self.all_bands.index(band))

        # Reorder and rescale the image
        image = sample['image'][rgb_indices].permute(1, 2, 0)
        image = torch.clamp(image / 10000, min=0, max=1).numpy()

        #         indices = []
        # for band in self.all_bands:
        #     indices.append(self.all_bands.index(band))

        # # Reorder and rescale the image
        # image = sample['image'][indices].permute(1, 2, 0)

        # Plot the image
        fig, ax = plt.subplots()
        ax.imshow(image)

        return fig
    
    def __getitem__(self,query):
        #dit is de source code van RasterDataSet.__getitem__() tot waar ik sample['test'] toevoeg
        #kan een work-around zijn?
        hits = self.index.intersection(tuple(query), objects=True)
        filepaths = cast(list[str], [hit.object for hit in hits])
        labels = pd.read_csv("data/labels.csv")


        if not filepaths:
            raise IndexError(
                f'query: {query} not found in index with bounds: {self.bounds}'
            )

        if self.separate_files:
            data_list: list[torch.Tensor] = []
            # labels_list = []
            filename_regex = re.compile(self.filename_regex, re.VERBOSE)
            for band in self.bands:
                band_filepaths = []
                for filepath in filepaths:
                    filename = os.path.basename(filepath)
                    directory = os.path.dirname(filepath)
                    match = re.match(filename_regex, filename)
                    if match:
                        if 'band' in match.groupdict():
                            start = match.start('band')
                            end = match.end('band')
                            filename = filename[:start] + band + filename[end:]
                    filepath = os.path.join(directory, filename)
                    band_filepaths.append(filepath)
                    # extracting labels
                    image_data = labels[labels["filename"] == filepath]
                    # labels_list = torch.Tensor(image_data["fuel_type"].values*13)  # transform to tensor for training, times 13 for each training band
                    # labels_list = torch.Tensor(image_data["fuel_type"])
                # labels_list = torch.Tensor([0])  # for testing
                # print(f'Nr. files in band {band}: {len(filepath)}')  # verify that each image is available in each band
                data_list.append(self._merge_files(band_filepaths, query))
            data = torch.cat(data_list)
        else:
            data = self._merge_files(filepaths, query, self.band_indexes)

        sample = {'crs': self.crs, 'bounds': query}

        labels_list = torch.Tensor([0]*169)
        data = data.to(self.dtype)
        if self.is_image:
            sample['image'] = data
        else:
            sample['mask'] = data

        #hier evt eigen dingen toevoegen aan de sample
        # sample['test'] = 1
        sample['label'] = labels_list

        if self.transforms is not None:
            sample = self.transforms(sample)

        print(self.bands)  # 13 bands, 
        print(sample['image'].shape)  # 169 images of 120x120
        print(sample['label'].shape)  # should have 13 labels? One for each band?
        # print(sample['mask'].shape)  # commented out, becuase does not exist yet

        return sample
        # return sample['image'], sample['label']
    
root = os.path.join(os.getcwd(), "data/images/images/training/120x120/positive")
dataset = Sentinel2(root)
torch.manual_seed(1)    
sampler = RandomGeoSampler(dataset,size=120,length=1) 
dataloader = DataLoader(dataset, sampler=sampler,collate_fn=stack_samples) 

### Part 0: Data preprocessing

In [5]:
default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)

path = os.getcwd()  # current path

# transforming data to allow terratorch to use it
for split in ['training', 'validation']:  # 'validation'
    datadir = os.path.join(path, "data\\images\\images")
    reg_file = os.path.join(path, "data\\labels.csv")
    seglabeldir = os.path.join(path, "data\\segmentation_labels\\segmentation_labels")
    # examples = pd.read_csv(dataset_dir / f'{split}.csv')
    examples = pd.read_csv(reg_file)  # extracting data from .csv
    labels = examples['filename'].values
    labels = [l.replace(":","_") for l in labels]  # reformatting filenames in labels.csv
    examples['filename'] = labels
    print(labels)

    # Create class dirs
    # Expected structure by the TerraTorch generic classification dataset, see:
    # from terratorch.datasets import GenericNonGeoClassificationDataset
    # from terratorch.datamodules import GenericNonGeoClassificationDataModule
    class_names = examples['fuel_type'].unique()  # defining classes 
    for class_name in class_names:
        os.makedirs(f'{path}/data/{split}/{class_name}', exist_ok=True)  # create new directories for train-validation sets

    # for i, row in tqdm.tqdm(examples.iterrows(), total=len(examples), desc=split, maxinterval=10):
    if split == 'training':
        # file_root = os.path.join(datadir, split, '120x120')  # only focus on the 120x120 in training for now
        file_root = os.path.join(datadir, split)  # use all images
    else: 
        file_root = os.path.join(datadir, split)  # only 120x120 images in validation


    for dirpath,dirnames,files in os.walk(file_root):
        for file in files:
            if not file.endswith('.tif'):  # dont need to add file if not a tiff
                continue
            filepath = os.path.join(dirpath,file)
            print(file)
    # for i, file in enumerate(examples['filename']):
        # loading all bands
        # positive_file = f'{datadir}/{split}/120x120/positive/{file}'
        # negative_file = f'{datadir}/{split}/120x120/negative/{file}'
        # if os.path.isfile(positive_file): # or not negative_file.exists():

            # if 'positive' in file:
            #     # print(f"Missing {file}")
            #     print(f"File is positive")
            #     # continue
            #     # with rio.open(positive_file) as src:
            #     with rio.open(file) as src:
            #         load_file = src.read()

            # # elif os.path.isfile(negative_file):
            # elif 'negative' in file:
            #     print(f"File is negative")

            #     # with rio.open(negative_file) as src:
            #     with rio.open(file) as src:
            #         load_file = src.read()
            # # ir = np.load(ir_file)
            # else:
            #     load_file=None
            
            with rio.open(filepath) as src:  # open file to write to new directory
                # if "300" in filepath:  # rescaling 300x300 images to 120x120
                #     load_file = src.read(
                #         out_shape=(
                #         src.count,
                #         120,
                #         120
                #         ),
                #     resampling=Resampling.bilinear
                #     )
                # else: 
                #     load_file = src.read(
                #         out_shape=(
                #         src.count,
                #         224,
                #         224
                #         ),
                #     resampling=Resampling.bilinear
                #     )
                load_file = src.read(
                    out_shape=(
                    src.count,
                    224,
                    224
                    ),
                resampling=Resampling.bilinear
                )
                # load_file = src.read()

            # Stack bands (to match Prithvi channels)
            # BLUE, GREEN, RED, NIR_NARROW, SWIR_1, SWIR_2 / Landsat: B02, B03, B04, B05, B06, B07
            # Note that you don't have to match the channels, you can also define them in the config.
            # stacked = np.concatenate([rgb[[2, 1, 0]], ir.transpose((2, 0, 1))], axis=0)
            # stacked = np.concatenate([band for band in ])

            # Save images in dedicated folder per class
            # out_file = dataset_dir / split / row['label'] / f'{split}_{i}.tif'
            if np.isin(file,labels):  # checking if file is present in .csv
                file_csv = examples[examples['filename']==os.path.basename(os.path.normpath(file))]
                # file_label = examples.loc[i,['fuel_type']].values
                file_label = file_csv['fuel_type'].values[0]  # extract label
                # print(file_csv)
                file_index = file_csv.index[0]  # extract .csv index
                # out_file = f'data/{split}/{file_label}/{split}_{i}.tif'
                out_file = f'data/{split}/{file_label}/{split}_{file_index}.tif'  # new filepath
                print(out_file)
                with rio.open(out_file, 
                            'w',
                            driver='GTiff',
                            width=224,  # changed from 120
                            height=224,  # changed from 120
                            dtype=load_file.dtype,
                            transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                            count=13) as dst:
                    dst.write(load_file)  # writing
            # with rio.open(out_file,
            #                 'w',
            #                 driver='GTiff',
            #                 height=332,
            #                 width=332,
            #                 dtype=stacked.dtype,
            #                 transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
            #                 count=6) as dst:
            #     dst.write(stacked)

    # if delete_examples:
    #     shutil.rmtree(dataset_dir / 'examples')


# if __name__ == '__main__':
#     main(delete_examples=False)

['0000__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0046__S2B-MSIL2A-ST20200113T104630-N0213-R008-T32ULC-20200113T112959.tif', '0002__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0000__S2A-MSIL2A-ST20200104T110726-N0213-R094-T30UYV-20200104T122020.tif', '0046__S2B-MSIL2A-ST20200113T104631-N0213-R008-T31UGT-20200113T112959.tif', '0057__S2A-MSIL2A-ST20200217T104629-N0214-R008-T32ULC-20200217T121511.tif', '0057__S2B-MSIL2A-ST20200212T104630-N0214-R008-T32ULC-20200213T134833.tif', '0000__S2A-MSIL2A-ST20200206T111719-N0214-R137-T30UYV-20200206T122704.tif', '0046__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0046__S2A-MSIL2A-ST20200217T104630-N0214-R008-T31UGT-20200217T121511.tif', '0057__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0042__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0057__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0046__S2A-

### Part 1: Defining datamodule for lightning trainer

In [14]:

# # Timm requires 224x224 input.
# train_transforms = albumentations.Compose([
#     albumentations.RandomCrop(height=224, width=224),
#     albumentations.HorizontalFlip(),
#     albumentations.pytorch.transforms.ToTensorV2(),
# ])
# val_transforms = albumentations.Compose([
#     albumentations.CenterCrop(height=224, width=224),
#     albumentations.pytorch.transforms.ToTensorV2(),
# ])

# means=[
#     960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
#     1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
#     1914.3, 1512.0585, 1512.0585
#     ]  # updated from dataset_multitask file, full
means=[1110.9012, 1250.0942, 1259.5178, 2251.6265, 1512.0585, 1512.0585]

# stds=[
#     1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
#     1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
#     1303.83, 1189.9052, 1189.9052
#     ]  # updated from dataset_multitask file 
stds=[1418.4988, 1381.5366, 1406.7112, 1604.1998, 1189.9052, 1189.9052]

datamodule = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=0,
    train_data_root=os.path.join(path, 'data', 'training'),
    val_data_root=os.path.join(path, 'data', 'validation'),
    test_data_root=os.path.join(path, 'data', 'validation'), 
    means=means,
    stds=stds,
    num_classes=4,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    # output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    output_bands=('B02', 'B03', 'B04','B08','B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216)
    no_data_replace=0,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

### Part 2: Defining classification error

In [15]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]


model_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        # "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),  # full
        "bands": ('B02', 'B03', 'B04','B08','B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 4,
        "necks":  VIT_UPERNET_NECK
}

task = ClassificationTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="ce",  # need cross-entropy for 
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True,
)

### Part 4: Training the model - Initialising and fitting lightning trainer

In [16]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir='output', name='tutorial')

# You can also log directly to WandB
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(log_model="all") 

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=1, # train only one epoch for demo
    default_root_dir='output/tutorial',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)
_ = trainer.fit(model=task, datamodule=datamodule)

c:\Users\alhst\anaconda3\envs\terratorch\Lib\site-packages\lightning\pytorch\trainer\connectors\accelerator_connector.py:556: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


### Part 5: Testing the finetuned model

In [18]:
res = trainer.test(model=task, datamodule=datamodule)

c:\Users\alhst\anaconda3\envs\terratorch\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


# Segmentation - Predicting segmentation cloud maps

In [None]:
default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)

path = os.getcwd()  # current path

# transforming data to allow terratorch to use it
for split in ['training', 'validation']:  # 'validation'
    datadir = os.path.join(path, "data\\images\\images")
    reg_file = os.path.join(path, "data\\labels.csv")
    seglabeldir = os.path.join(path, "data\\segmentation_labels\\segmentation_labels")
    # examples = pd.read_csv(dataset_dir / f'{split}.csv')
    examples = pd.read_csv(reg_file)  # extracting data from .csv
    labels = examples['filename'].values
    labels = [l.replace(":","_") for l in labels]  # reformatting filenames in labels.csv
    examples['filename'] = labels
    print(labels)

    # Create class dirs
    # Expected structure by the TerraTorch generic classification dataset, see:
    # from terratorch.datasets import GenericNonGeoClassificationDataset
    # from terratorch.datamodules import GenericNonGeoClassificationDataModule
    class_names = examples['fuel_type'].unique()  # defining classes 
    for class_name in class_names:
        os.makedirs(f'{path}/data/{split}/{class_name}', exist_ok=True)  # create new directories for train-validation sets

    # for i, row in tqdm.tqdm(examples.iterrows(), total=len(examples), desc=split, maxinterval=10):
    if split == 'training':
        file_root = os.path.join(datadir, split, '120x120')  # only focus on the 120x120 in training for now
    else: 
        file_root = os.path.join(datadir, split)  # only 120x120 images in validation


    for dirpath,dirnames,files in os.walk(file_root):
        for file in files:
            if not file.endswith('.tif'):  # dont need to add file if not a tiff
                continue
            filepath = os.path.join(dirpath,file)
            print(file)
    # for i, file in enumerate(examples['filename']):
        # loading all bands
        # positive_file = f'{datadir}/{split}/120x120/positive/{file}'
        # negative_file = f'{datadir}/{split}/120x120/negative/{file}'
        # if os.path.isfile(positive_file): # or not negative_file.exists():

            # if 'positive' in file:
            #     # print(f"Missing {file}")
            #     print(f"File is positive")
            #     # continue
            #     # with rio.open(positive_file) as src:
            #     with rio.open(file) as src:
            #         load_file = src.read()

            # # elif os.path.isfile(negative_file):
            # elif 'negative' in file:
            #     print(f"File is negative")

            #     # with rio.open(negative_file) as src:
            #     with rio.open(file) as src:
            #         load_file = src.read()
            # # ir = np.load(ir_file)
            # else:
            #     load_file=None
            
            with rio.open(filepath) as src:  # open file to write to new directory
                load_file = src.read()

            # Stack bands (to match Prithvi channels)
            # BLUE, GREEN, RED, NIR_NARROW, SWIR_1, SWIR_2 / Landsat: B02, B03, B04, B05, B06, B07
            # Note that you don't have to match the channels, you can also define them in the config.
            # stacked = np.concatenate([rgb[[2, 1, 0]], ir.transpose((2, 0, 1))], axis=0)
            # stacked = np.concatenate([band for band in ])

            # Save images in dedicated folder per class
            # out_file = dataset_dir / split / row['label'] / f'{split}_{i}.tif'
            if np.isin(file,labels):  # checking if file is present in .csv
                file_csv = examples[examples['filename']==os.path.basename(os.path.normpath(file))]
                # file_label = examples.loc[i,['fuel_type']].values
                file_label = file_csv['fuel_type'].values[0]  # extract label
                print(file_csv)
                file_index = file_csv.index[0]  # extract .csv index
                # out_file = f'data/{split}/{file_label}/{split}_{i}.tif'
                out_file = f'data/{split}/{file_label}/{split}_{file_index}.tif'  # new filepath
                print(out_file)
                with rio.open(out_file, 
                            'w',
                            driver='GTiff',
                            width=120,
                            height=120,
                            dtype=load_file.dtype,
                            transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                            count=13) as dst:
                    dst.write(load_file)  # writing
            # with rio.open(out_file,
            #                 'w',
            #                 driver='GTiff',
            #                 height=332,
            #                 width=332,
            #                 dtype=stacked.dtype,
            #                 transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
            #                 count=6) as dst:
            #     dst.write(stacked)

    # if delete_examples:
    #     shutil.rmtree(dataset_dir / 'examples')


# if __name__ == '__main__':
#     main(delete_examples=False)

In [9]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]

model_seg_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 2,
        "necks":  VIT_UPERNET_NECK
}

task = SemanticSegmentationTask(
    model_args=model_seg_args,
    model_factory="PrithviModelFactory",
    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True
)

In [None]:
means=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585, 1512.0585
    ]  # updated from dataset_multitask file

stds=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052, 1189.9052
    ]  # updated from dataset_multitask file 

datamodule_seg = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=0,
    train_data_root=os.path.join(path, 'data', 'training'),
    val_data_root=os.path.join(path, 'data', 'validation'),
    test_data_root=os.path.join(path, 'data', 'validation'),  # reusing the validation set for testing
    img_grep=
    means=means,
    stds=stds,
    num_classes=4,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216)
    no_data_replace=0,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

# Custom Regression Head

In [108]:
# Trying regression
from regression_tasks import RegressionTask

model_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 1,
        "necks":  VIT_UPERNET_NECK
}

task = RegressionTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="mse",  # need cross-entropy for 
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True,
)

KeyError: 'PrithviModelFactory'

# Ignore all cells below

In [46]:
for batch in dataloader:
    sample = unbind_samples(batch)
    # print(sample[0]['image'])
    indices = []
    bands = tuple([f'B0{i}' for i in range(1,14)])
    for band in bands:
        indices.append(bands.index(band))

    # Reorder and rescale the image
    # image = sample[0]['image'][indices].permute(1, 2, 0)
    image = sample[0]['image'][indices]
    print(image.shape)

NameError: name 'dataloader' is not defined

In [None]:
from terratorch.datasets import HLSBands

batch_size = 1
num_workers = 0
# train_val_test = [
#     "burn_scar_segmentation_toy/train_images",
#     "burn_scar_segmentation_toy/val_images",
#     "burn_scar_segmentation_toy/test_images",
# ]
train_val_test = [
    os.path.join(path, "data\images\images", 'training/300x300/'),
    os.path.join(path, "data\images\images", 'validation/300x300/'),
    os.path.join(path, "data\images\images", 'validation/300x300/')
]

# train_val_test_labels = {
#     "train_label_data_root": "burn_scar_segmentation_toy/train_labels",
#     "val_label_data_root": "burn_scar_segmentation_toy/val_labels",
#     "test_label_data_root": "burn_scar_segmentation_toy/test_labels",
# }
train_val_test_labels = {
    "train_label_data_root": "burn_scar_segmentation_toy/train_labels",
    "val_label_data_root": "burn_scar_segmentation_toy/val_labels",
    "test_label_data_root": "burn_scar_segmentation_toy/test_labels",
}  # still to edit



# from https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/configs/burn_scars.py

means=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585
    ]  # updated from dataset_multitask file

stds=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052
    ]  # updated from dataset_multitask file 

# datamodule = GenericNonGeoSegmentationDataModule(
#     batch_size,
#     num_workers,
#     *train_val_test,
#     "*_merged.tif", # img grep
#     "*.mask.tif", # label grep
#     means,
#     stds,
#     2, # num classes
#     **train_val_test_labels,

#     # if transforms are defined with Albumentations, you can pass them here
#     # train_transform=train_transform,
#     # val_transform=val_transform,
#     # test_transform=test_transform,

#     # edit the below for your usecase
#     dataset_bands=[
#         HLSBands.BLUE,
#         HLSBands.GREEN,
#         HLSBands.RED,
#         HLSBands.NIR_NARROW,
#         HLSBands.SWIR_1,
#         HLSBands.SWIR_2,
#     ],
#     output_bands=[
#         HLSBands.BLUE,
#         HLSBands.GREEN,
#         HLSBands.RED,
#         HLSBands.NIR_NARROW,
#         HLSBands.SWIR_1,
#         HLSBands.SWIR_2,
#     ],
#     no_data_replace=0,
#     no_label_replace=-1,
# )

datamodule = GenericNonGeoSegmentationDataModule(
    batch_size,
    num_workers,
    *train_val_test,
    img_grep="*_merged.tif", # img grep
    label_grep="*.mask.tif", # label grep
    test_data_root="_merged.tif",
    test_label_data_root=".mask.tif",
    means=means,
    stds=stds,
    num_classes=2, # num classes
    **train_val_test_labels,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transform,
    # val_transform=val_transform,
    # test_transform=test_transform,

    # edit the below for your usecase
    dataset_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    output_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    no_data_replace=0,
    no_label_replace=-1,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

TypeError: terratorch.datamodules.generic_pixel_wise_data_module.GenericNonGeoSegmentationDataModule() got multiple values for keyword argument 'test_label_data_root'

### Part 2: Defining Trainer and Custom Dataloader

In [82]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from torchgeo.samplers import GridGeoSampler
from torchgeo.datasets.splits import random_bbox_assignment

checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir='output', name='tutorial')

# You can also log directly to WandB
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(log_model="all") 

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=1, # train only one epoch for demo
    default_root_dir='output/test',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)

# for batch_idx, batch in enumerate(train_dl.keys()):
#     # print(batch)
#     print(batch)

class CustomGeoDataModule(GeoDataModule):  # defining a custom datamodule to feed it to the trainer
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)
        
        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator)  # not sure what this does yet BUT IT IS VERY NECESSARY
        # Creating train-val-test split???
        
        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

custom_datamodule = CustomGeoDataModule(type(dataset), batch_size=2, patch_size=120, length=1)  # runtime error perhaps due to num_workers, could try 0 if commenting this out doesn't work (parallell resources)
custom_datamodule.setup("fit")
# custom_datamodule = GeoDataModule(type(dataset), batch_size=1, patch_size=120, length=1, num_workers=6)  # previous module, doesn't work (gives "split" error)

c:\Users\alhst\anaconda3\envs\terratorch\Lib\site-packages\lightning\pytorch\trainer\connectors\accelerator_connector.py:556: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


NameError: name 'GeoDataModule' is not defined

In [59]:
_ = trainer.fit(model=task, train_dataloaders=custom_datamodule)

ValueError: Expected input batch_size (26) to match target batch_size (2).

In [None]:
print(f"The model was pretrained on bands {task._timm_module.pretrained_bands}.\n The model is using bands {model._timm_module.model_bands}")