## This file is a copy of "terratorch_implementation.ipynb" for the purpose of working with the regressor.

I deleted ignorable cells from this file to save up on unnecessary clutter.

In [1]:
import os
import torch

from terratorch.tasks import ClassificationTask, PixelwiseRegressionTask, SemanticSegmentationTask
from regression_tasks import RegressionTask #our own custom single value RegressionTask based on ClassificationTask

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, GeoDataset, UnionDataset
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler,GeoSampler,RandomBatchGeoSampler

import terratorch.models.backbones.prithvi_vit as prithvi_vit

from terratorch.datamodules import GenericNonGeoSegmentationDataModule
# from dataset_original import create_dataset
import pandas as pd
from torch.utils.data import DataLoader, ConcatDataset, RandomSampler

from terratorch.datamodules import GenericNonGeoClassificationDataModule
from RegressionData import GenericNonGeoClassificationRegressionDataModule #our own "custom" single value Regression DataModule

# import tqdm
import rasterio as rio
import numpy as np
import pandas as pd

from rasterio.enums import Resampling

from data_preprocessing import create_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# print(timm.list_pretrained())
# print(help(terratorch.tasks.ClassificationTask))

## Regression - Finetune Prithvi to act as a regression model

## Training model

Part 1: Constructing custom RasterDataset to transform our data to useable format

### Part 0: Data preprocessing

In [2]:
if 'UCProjectGroup1' in os.getcwd(): # For Tim's code to move to right dir
    os.chdir('..')

# Creating train-validation-test split
data_path =  os.path.join(os.getcwd(), "data") 
labels = pd.read_csv(os.path.join(data_path,'labels.csv')) 

#correct label filenames since they are different in the csv
label_filenames = labels['filename'].values
new_label_filenames = [l.replace(":","_") for l in label_filenames]
labels['filename'] = new_label_filenames

train, val, test= create_split(labels)
data_split ={"training":train, "validation":val, "test":test}
print(data_split["training"].shape, data_split["validation"].shape, data_split["test"].shape)

(1705,) (212,) (214,)


Preprocessing the data by re-ordering, for REGRESSION e.g. regression ground truth is added to filename and it is saved in a seperate regressiondata folder.

In [3]:
default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)

if 'UCProjectGroup1' in os.getcwd(): # For Tim's code to move to right dir if need be
    os.chdir('..')
path = os.getcwd()  # current path

if not os.path.isdir('data\\regressiondata'):
    os.mkdir('data\\regressiondata')

In [4]:

# transforming data to allow terratorch to use it
datadir = os.path.join(path, "data\\images\\images")
reg_file = os.path.join(path, "data\\labels.csv")
seglabeldir = os.path.join(path, "data\\segmentation_labels\\segmentation_labels")
# examples = pd.read_csv(dataset_dir / f'{split}.csv')
examples = pd.read_csv(reg_file)  # extracting data from .csv
filenames = examples['filename'].values
filenames = [l.replace(":","_") for l in filenames]  # reformatting filenames in labels.csv
examples['filename'] = filenames
print(filenames)
for split in data_split.keys():  # 'validation'
    # Create split dirs in regressiondata
    os.makedirs(f'{path}/data/regressiondata/{split}',exist_ok=True)
    class_names = examples['fuel_type'].unique()  # defining classes 
    for class_name in class_names:
        os.makedirs(f'{path}/data/regressiondata/{split}/{class_name}', exist_ok=True)  # create new directories for train-validation sets


for dirpath,dirnames,files in os.walk(datadir):
    for file in files:
        if not file.endswith('.tif'):  # dont need to add file if not a tiff
            continue
        filepath = os.path.join(dirpath,file)

        if file in data_split["training"]:
            split = "training"
        elif file in data_split["validation"]:
            split = "validation"
        elif file in data_split["test"]:
            split = "test"
        else: 
            print("File is not in the list")
        print(f'Taking: {file}',end='\t')
        
        with rio.open(filepath) as src:  # open file to write to new directory
            load_file = src.read(
                out_shape=(
                src.count,
                224,
                224
                ),
            resampling=Resampling.bilinear
            )

        # Stack bands (to match Prithvi channels)
        # BLUE, GREEN, RED, NIR_NARROW, SWIR_1, SWIR_2 / Landsat: B02, B03, B04, B05, B06, B07
        # Note that you don't have to match the channels, you can also define them in the config.
        # stacked = np.concatenate([rgb[[2, 1, 0]], ir.transpose((2, 0, 1))], axis=0)
        # stacked = np.concatenate([band for band in ])

        # Save images in dedicated folder per class
        # out_file = dataset_dir / split / row['label'] / f'{split}_{i}.tif'
        if np.isin(file,filenames):  # checking if file is present in labels.csv
            file_csv = examples[examples['filename']==os.path.basename(os.path.normpath(file))] #file specific Series object
            file_regression = file_csv['gen_output'].values[0] # extract regression ground truth
            file_label = file_csv['fuel_type'].values[0]  # extract label e.g. fuel type ground truth
            file_index = file_csv.index[0]  # extract .csv index
            out_file = f'data\\regressiondata\\{split}\\{file_label}\\{split}_{file_index}_{str(file_regression)}.tif'  # new filepath with regression ground truth, index added to make it unique
            if os.path.isfile(out_file): 
                print(f'{out_file} already exists! Skipping.')
                continue 
            print(f'Moving to: {out_file}')
            print()
            with rio.open(out_file, 
                        'w',
                        driver='GTiff',
                        width=224,  # changed from 120
                        height=224,  # changed from 120
                        dtype=load_file.dtype,
                        transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                        count=13) as dst:
                dst.write(load_file)  # writing

['0000__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0046__S2B-MSIL2A-ST20200113T104630-N0213-R008-T32ULC-20200113T112959.tif', '0002__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0000__S2A-MSIL2A-ST20200104T110726-N0213-R094-T30UYV-20200104T122020.tif', '0046__S2B-MSIL2A-ST20200113T104631-N0213-R008-T31UGT-20200113T112959.tif', '0057__S2A-MSIL2A-ST20200217T104629-N0214-R008-T32ULC-20200217T121511.tif', '0057__S2B-MSIL2A-ST20200212T104630-N0214-R008-T32ULC-20200213T134833.tif', '0000__S2A-MSIL2A-ST20200206T111719-N0214-R137-T30UYV-20200206T122704.tif', '0046__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0046__S2A-MSIL2A-ST20200217T104630-N0214-R008-T31UGT-20200217T121511.tif', '0057__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0042__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0057__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0046__S2A-

In [5]:
# count number of files in each folder after splitting
total = 0
for split in data_split.keys():
    split_path = os.path.join('data\\regressiondata',split)
    file_count = sum(len(files) for _, _, files in os.walk(split_path))
    total += file_count
    print(f"{split}: {file_count} files ({(file_count/2131)*100:.2f}%)")

print(f'Total: {total}, should be 2131')

training: 1705 files (80.01%)
validation: 212 files (9.95%)
test: 214 files (10.04%)
Total: 2131, should be 2131


### JSON to tiff for segmentation

In [None]:
# # Rastering segmentation json files
# import json
# from shapely.geometry import Polygon
# from rasterio.features import rasterize
# import geopandas as gpd
# import matplotlib.pyplot as plt

# path = os.getcwd()  # current path
# data_path =  os.path.join(os.getcwd(), "data") 
# datadir = os.path.join(path, "data\\images\\images")
# seglabeldir = os.path.join(path, "data\\segmentation_labels\\segmentation_labels")

# # create folder to host the new segmentation maps
# os.makedirs(f'{path}/data/labels', exist_ok=True)
# default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)
# seglabels = []
# segfile_lookup = {}

# # Make a lookup table for each segmentation file
# idx = 0
# for dirpath,dirnames,files in os.walk(seglabeldir):
#     for seglabelfile in files:
#         if not os.path.join(dirpath,seglabelfile).endswith(".json"):
#             continue
#         segdata = json.load(open(os.path.join(dirpath,
#                                                 seglabelfile), 'r'))
#         seglabels.append(segdata)
#         segfile_lookup[
#             "-".join(segdata['data']['image'].split('-')[1:]).replace(
#                 '.png', '.tif')] = idx
#         idx+=1

# seglabels_poly = []
# # read in image file names for positive images
# idx = 0
# for root, _, files in os.walk(data_path):
#     for filename in files:
#         if not filename.endswith('.tif'):
#             continue
#         if filename not in segfile_lookup.keys():
#             continue
#         img_path = os.path.join(root, filename)

#         # extracting image size
#         if "120x120" in root:
#             size = 120
#         elif "300x300" in root:
#             size =300
#         else: 
#             print("Outlier size image")

#         polygons = []
#         for completions in seglabels[segfile_lookup[filename]]['completions']:
#             for result in completions['result']:
#                 polygons.append(
#                     np.array(
#                         result['value']['points'] + [result['value']['points'][0]]) * size / 100)
#         with rio.open(img_path, 'r') as src:
#             img_file = src.read()

#     # rasterize segmentation polygons
#         fptdata = np.zeros((img_file.shape[1], img_file.shape[2]), dtype=np.uint8)
#         # polygons = seglabels_poly.copy()
#         shapes = []
#         if len(polygons) > 0: # add polygons if clouds are present in images
#             for pol in polygons:
#                 try:
#                     pol = Polygon(pol)
#                     shapes.append(pol)
#                 except ValueError:
#                     continue
#             polygon_geom = [(g, 1) for g in shapes]
#             fptdata = rasterize([(g, 1) for g in shapes],
#                                 out_shape=fptdata.shape,
#                                 all_touched=True)
            
#             # code to plot the masks alongside 1st band images
#             # fig, ax = plt.subplots()
#             # plt.imshow(img_file[1, :, :])
#             # plt.show()
#             # plt.imshow(fptdata, cmap='grey')
#             # plt.show()

#         # convert raster to tiff 
#         mask_name = f"data/labels/{filename}.tif"
#         with rio.open(mask_name, 
#                     'w',
#                     driver='GTiff',
#                     width=size,
#                     height=size,
#                     dtype=fptdata.dtype,
#                     transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
#                     count=1) as dst:
#             dst.write(fptdata, 1)  # writing
#         idx+=1

### Part 1: Defining datamodule for lightning trainer

In [26]:

# # Timm requires 224x224 input.
# train_transforms = albumentations.Compose([
#     albumentations.RandomCrop(height=224, width=224),
#     albumentations.HorizontalFlip(),
#     albumentations.pytorch.transforms.ToTensorV2(),
# ])
# val_transforms = albumentations.Compose([
#     albumentations.CenterCrop(height=224, width=224),
#     albumentations.pytorch.transforms.ToTensorV2(),
# ])

# means of full 13 bands
# means=[
#     960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
#     1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
#     1914.3, 1512.0585, 1512.0585
#     ]  # updated from dataset_multitask file, full
# means of 6 Prithvi bands
means=[1110.9012, 1250.0942, 1259.5178, 2251.6265, 1512.0585, 1512.0585]

# stds of full 13 bands
# stds=[
#     1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
#     1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
#     1303.83, 1189.9052, 1189.9052
#     ]  # updated from dataset_multitask file 
# stds of 6 Prithvi bands
stds=[1418.4988, 1381.5366, 1406.7112, 1604.1998, 1189.9052, 1189.9052]

datamodule = GenericNonGeoClassificationRegressionDataModule( #changed from classification to our modified one
    batch_size=16,
    num_workers=8,
    train_data_root=os.path.join(path, 'data','regressiondata', 'training'),
    val_data_root=os.path.join(path, 'data','regressiondata', 'validation'),
    test_data_root=os.path.join(path, 'data','regressiondata', 'test'), 
    means=means,
    stds=stds,
    num_classes=1,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    # output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    output_bands=('B02', 'B03', 'B04','B08','B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216)
    no_data_replace=0,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

### Part 2: Defining the task

In [31]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]


if 'alhst' in os.getcwd():
    weights_filepath = r"C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"
else:
    weights_filepath = r"C:\Users\timvd\Documents\Uni_2024-2025\UC\Project\ProjectCode\Prithvi_EO_V1_100M.pt"

model_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        # "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
        "bands":('B02', 'B03', 'B04','B08','B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": weights_filepath}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 1, #since univariate regression
        "necks":  VIT_UPERNET_NECK
}

task = RegressionTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="mse",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True,
    class_names=['Power Output']) #unused anyway

# bins (classification)
# sloop activation function eruit
# investigate custom head??? Baseclass?



In [32]:
print(datamodule.train_dataset)
print(datamodule.batch_size)
print(datamodule.output_bands)

GenericClassificationRegressionDataset Dataset
    type: NonGeoDataset
    size: 1705
16
('B02', 'B03', 'B04', 'B08', 'B012', 'B013')


### Part 4: Training the model - Initialising and fitting lightning trainer

In [33]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir='output', name='tutorial')

# You can also log directly to WandB
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(log_model="all") 

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=1, # train only one epoch for demo
    default_root_dir='output/tutorial',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)
_ = trainer.fit(model=task, datamodule=datamodule)

c:\Users\timvd\anaconda3\envs\UC-env\Lib\site-packages\lightning\pytorch\trainer\connectors\accelerator_connector.py:512: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)


INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


RuntimeError: Boolean value of Tensor with more than one value is ambiguous

### Part 5: Testing the finetuned model

In [18]:
res = trainer.test(model=task, datamodule=datamodule)

c:\Users\alhst\anaconda3\envs\terratorch\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


# Segmentation - Predicting segmentation cloud maps

### Processing the segmentation masks into the correct folders

In [42]:
default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)

path = os.getcwd()  # current path
os.makedirs(f'{path}/data/segmentation', exist_ok=True)  # make dir for segmentation data

# transforming data to allow terratorch to use it
for split in data_split.keys():  # 'validation'
    datadir = os.path.join(path, "data\\images\\images")
    reg_file = os.path.join(path, "data\\labels.csv")
    seglabeldir = os.path.join(path, "data\\segmentation_labels\\segmentation_labels")
    # examples = pd.read_csv(dataset_dir / f'{split}.csv')
    examples = pd.read_csv(reg_file)  # extracting data from .csv
    labels = examples['filename'].values
    labels = [l.replace(":","_") for l in labels]  # reformatting filenames in labels.csv
    examples['filename'] = labels
    print(labels)

    # Create class dirs
    class_names = examples['fuel_type'].unique()  # defining classes 
    for class_name in class_names:
        os.makedirs(f'{path}/data/segmentation/{split}/{class_name}', exist_ok=True)  # create new directories for train-validation sets

    # if split == 'training':
    #     # file_root = os.path.join(datadir, split, '120x120')  # only focus on the 120x120 in training for now
    #     file_root = os.path.join(datadir, split)  # use all images
    # else: 
    #     file_root = os.path.join(datadir, split)  # only 120x120 images in validation    

for dirpath,dirnames,files in os.walk(datadir):
    for file in files:
        if not file.endswith('.tif'):  # dont need to add file if not a tiff
            continue
        filepath = os.path.join(dirpath,file)  # image file
        if not '120x120' in filepath:  # try with 120x120 images first
            continue
        segmentationpath = os.path.join(path, 'data', 'labels', file)  #corresponding segmentation file

        if file in data_split["training"]:
            split = "training"
        elif file in data_split["validation"]:
            split = "validation"
        elif file in data_split["test"]:
            split = "test"
        else: 
            print("File is not in the list")
        print(file)
        
        with rio.open(filepath) as src:  # open file to write to new directory
            load_file = src.read(
                out_shape=(
                src.count,
                120,
                120
                )
            )
        with rio.open(filepath) as src:  # open file to write to new directory
            load_file = src.read(
                out_shape=(
                src.count,
                120,
                120
                )
            )

        # Stack bands (to match Prithvi channels)
        # BLUE, GREEN, RED, NIR_NARROW, SWIR_1, SWIR_2 / Landsat: B02, B03, B04, B05, B06, B07
        # Note that you don't have to match the channels, you can also define them in the config.
        # stacked = np.concatenate([rgb[[2, 1, 0]], ir.transpose((2, 0, 1))], axis=0)
        # stacked = np.concatenate([band for band in ])

        # Save images in dedicated folder per class
        # out_file = dataset_dir / split / row['label'] / f'{split}_{i}.tif'
        if np.isin(file,labels):  # checking if file is present in .csv
            file_csv = examples[examples['filename']==os.path.basename(os.path.normpath(file))]
            # file_label = examples.loc[i,['fuel_type']].values
            file_label = file_csv['fuel_type'].values[0]  # extract label
            # print(file_csv)
            file_index = file_csv.index[0]  # extract .csv index
            # out_file = f'data/{split}/{file_label}/{split}_{i}.tif'
            out_file = f'data/segmentation/{split}/{file_label}/{split}_{file_index}.tif'  # new filepath
            seg_out_file = f'data/segmentation/{split}/{file_label}/{split}_{file_index}.mask.tif'  # new filepath
            print(out_file)
            with rio.open(out_file, 
                        'w',
                        driver='GTiff',
                        width=120,  # change all to 120
                        height=120,  # change all to 120
                        dtype=load_file.dtype,
                        transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                        count=13) as dst:
                dst.write(load_file)  # writing
            with rio.open(seg_out_file, 
                        'w',
                        driver='GTiff',
                        width=120,  # change all to 120
                        height=120,  # change all to 120
                        dtype=load_file.dtype,
                        transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                        count=13) as dst:
                dst.write(load_file)  # writing

['0000__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0046__S2B-MSIL2A-ST20200113T104630-N0213-R008-T32ULC-20200113T112959.tif', '0002__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0000__S2A-MSIL2A-ST20200104T110726-N0213-R094-T30UYV-20200104T122020.tif', '0046__S2B-MSIL2A-ST20200113T104631-N0213-R008-T31UGT-20200113T112959.tif', '0057__S2A-MSIL2A-ST20200217T104629-N0214-R008-T32ULC-20200217T121511.tif', '0057__S2B-MSIL2A-ST20200212T104630-N0214-R008-T32ULC-20200213T134833.tif', '0000__S2A-MSIL2A-ST20200206T111719-N0214-R137-T30UYV-20200206T122704.tif', '0046__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0046__S2A-MSIL2A-ST20200217T104630-N0214-R008-T31UGT-20200217T121511.tif', '0057__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0042__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0057__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0046__S2A-

In [9]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]

model_seg_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 2,
        "necks":  VIT_UPERNET_NECK
}

task = SemanticSegmentationTask(
    model_args=model_seg_args,
    model_factory="PrithviModelFactory",
    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True
)

In [None]:
means=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585, 1512.0585
    ]  # updated from dataset_multitask file

stds=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052, 1189.9052
    ]  # updated from dataset_multitask file 

datamodule_seg = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=0,
    train_data_root=os.path.join(path, 'data', 'training'),
    val_data_root=os.path.join(path, 'data', 'validation'),
    test_data_root=os.path.join(path, 'data', 'validation'),  # reusing the validation set for testing
    img_grep=
    means=means,
    stds=stds,
    num_classes=4,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216)
    no_data_replace=0,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

In [108]:
# Trying regression
from regression_tasks import RegressionTask

model_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": "C:/Users/alhst/Documents/AI Master/Urban Computing/Project/Prithvi/Files/Prithvi_EO_V1_100M.pt"}, # FUCK THE EO PEOPLE ON HUGGINGFACE FOR RENAMING THE FILE YOU PIECES OF SHIT
        "pretrained":False,
        "num_classes": 1,
        "necks":  VIT_UPERNET_NECK
}

task = RegressionTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="mse",  # need cross-entropy for 
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=True,
)

KeyError: 'PrithviModelFactory'

# Ignore all cells below

In [46]:
for batch in dataloader:
    sample = unbind_samples(batch)
    # print(sample[0]['image'])
    indices = []
    bands = tuple([f'B0{i}' for i in range(1,14)])
    for band in bands:
        indices.append(bands.index(band))

    # Reorder and rescale the image
    # image = sample[0]['image'][indices].permute(1, 2, 0)
    image = sample[0]['image'][indices]
    print(image.shape)

NameError: name 'dataloader' is not defined

In [None]:
from terratorch.datasets import HLSBands

batch_size = 1
num_workers = 0
# train_val_test = [
#     "burn_scar_segmentation_toy/train_images",
#     "burn_scar_segmentation_toy/val_images",
#     "burn_scar_segmentation_toy/test_images",
# ]
train_val_test = [
    os.path.join(path, "data\images\images", 'training/300x300/'),
    os.path.join(path, "data\images\images", 'validation/300x300/'),
    os.path.join(path, "data\images\images", 'validation/300x300/')
]

# train_val_test_labels = {
#     "train_label_data_root": "burn_scar_segmentation_toy/train_labels",
#     "val_label_data_root": "burn_scar_segmentation_toy/val_labels",
#     "test_label_data_root": "burn_scar_segmentation_toy/test_labels",
# }
train_val_test_labels = {
    "train_label_data_root": "burn_scar_segmentation_toy/train_labels",
    "val_label_data_root": "burn_scar_segmentation_toy/val_labels",
    "test_label_data_root": "burn_scar_segmentation_toy/test_labels",
}  # still to edit



# from https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/configs/burn_scars.py

means=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585
    ]  # updated from dataset_multitask file

stds=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052
    ]  # updated from dataset_multitask file 

# datamodule = GenericNonGeoSegmentationDataModule(
#     batch_size,
#     num_workers,
#     *train_val_test,
#     "*_merged.tif", # img grep
#     "*.mask.tif", # label grep
#     means,
#     stds,
#     2, # num classes
#     **train_val_test_labels,

#     # if transforms are defined with Albumentations, you can pass them here
#     # train_transform=train_transform,
#     # val_transform=val_transform,
#     # test_transform=test_transform,

#     # edit the below for your usecase
#     dataset_bands=[
#         HLSBands.BLUE,
#         HLSBands.GREEN,
#         HLSBands.RED,
#         HLSBands.NIR_NARROW,
#         HLSBands.SWIR_1,
#         HLSBands.SWIR_2,
#     ],
#     output_bands=[
#         HLSBands.BLUE,
#         HLSBands.GREEN,
#         HLSBands.RED,
#         HLSBands.NIR_NARROW,
#         HLSBands.SWIR_1,
#         HLSBands.SWIR_2,
#     ],
#     no_data_replace=0,
#     no_label_replace=-1,
# )

datamodule = GenericNonGeoSegmentationDataModule(
    batch_size,
    num_workers,
    *train_val_test,
    img_grep="*_merged.tif", # img grep
    label_grep="*.mask.tif", # label grep
    test_data_root="_merged.tif",
    test_label_data_root=".mask.tif",
    means=means,
    stds=stds,
    num_classes=2, # num classes
    **train_val_test_labels,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transform,
    # val_transform=val_transform,
    # test_transform=test_transform,

    # edit the below for your usecase
    dataset_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    output_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    no_data_replace=0,
    no_label_replace=-1,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

TypeError: terratorch.datamodules.generic_pixel_wise_data_module.GenericNonGeoSegmentationDataModule() got multiple values for keyword argument 'test_label_data_root'

### Part 2: Defining Trainer and Custom Dataloader

In [82]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from torchgeo.samplers import GridGeoSampler
from torchgeo.datasets.splits import random_bbox_assignment

checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir='output', name='tutorial')

# You can also log directly to WandB
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(log_model="all") 

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=1, # train only one epoch for demo
    default_root_dir='output/test',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)

# for batch_idx, batch in enumerate(train_dl.keys()):
#     # print(batch)
#     print(batch)

class CustomGeoDataModule(GeoDataModule):  # defining a custom datamodule to feed it to the trainer
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)
        
        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator)  # not sure what this does yet BUT IT IS VERY NECESSARY
        # Creating train-val-test split???
        
        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

custom_datamodule = CustomGeoDataModule(type(dataset), batch_size=2, patch_size=120, length=1)  # runtime error perhaps due to num_workers, could try 0 if commenting this out doesn't work (parallell resources)
custom_datamodule.setup("fit")
# custom_datamodule = GeoDataModule(type(dataset), batch_size=1, patch_size=120, length=1, num_workers=6)  # previous module, doesn't work (gives "split" error)

c:\Users\alhst\anaconda3\envs\terratorch\Lib\site-packages\lightning\pytorch\trainer\connectors\accelerator_connector.py:556: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


NameError: name 'GeoDataModule' is not defined

In [59]:
_ = trainer.fit(model=task, train_dataloaders=custom_datamodule)

ValueError: Expected input batch_size (26) to match target batch_size (2).

In [None]:
print(f"The model was pretrained on bands {task._timm_module.pretrained_bands}.\n The model is using bands {model._timm_module.model_bands}")