In [1]:
!pip install albucore==0.0.16



In [1]:
import os
import torch
import timm

import terratorch
from terratorch.tasks import ClassificationTask, PixelwiseRegressionTask, SemanticSegmentationTask

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, GeoDataset, UnionDataset
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler,GeoSampler,RandomBatchGeoSampler

import terratorch.models.backbones.prithvi_vit as prithvi_vit
from terratorch.datamodules import GenericNonGeoSegmentationDataModule, GenericNonGeoClassificationDataModule
import pandas as pd
from torch.utils.data import DataLoader, ConcatDataset, RandomSampler

import tqdm
import rasterio as rio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from rasterio.enums import Resampling

from data_preprocessing import create_split, crop_image_and_segmentation

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

#For tim
if 'UCProjectGroup1' in os.getcwd():
    os.chdir('..')

INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.0 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
print(timm.list_pretrained())
print(help(terratorch.tasks.ClassificationTask))

['bat_resnext26ts.ch_in1k', 'beit_base_patch16_224.in22k_ft_in22k', 'beit_base_patch16_224.in22k_ft_in22k_in1k', 'beit_base_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_224.in22k_ft_in22k', 'beit_large_patch16_224.in22k_ft_in22k_in1k', 'beit_large_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_512.in22k_ft_in22k_in1k', 'beitv2_base_patch16_224.in1k_ft_in1k', 'beitv2_base_patch16_224.in1k_ft_in22k', 'beitv2_base_patch16_224.in1k_ft_in22k_in1k', 'beitv2_large_patch16_224.in1k_ft_in1k', 'beitv2_large_patch16_224.in1k_ft_in22k', 'beitv2_large_patch16_224.in1k_ft_in22k_in1k', 'botnet26t_256.c1_in1k', 'caformer_b36.sail_in1k', 'caformer_b36.sail_in1k_384', 'caformer_b36.sail_in22k', 'caformer_b36.sail_in22k_ft_in1k', 'caformer_b36.sail_in22k_ft_in1k_384', 'caformer_m36.sail_in1k', 'caformer_m36.sail_in1k_384', 'caformer_m36.sail_in22k', 'caformer_m36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k_384', 'caformer_s18.sail_in1k', 'caformer_s18.sail_in1k_384', 'caformer_s

## Classification - Finetune Prithvi to act as a classification model

### Training model

Part 1: Constructing custom RasterDataset to transform our data to useable format

### Part 0: Data preprocessing

In [53]:
# Rastering segmentation masks according to json files 
import json
from shapely.geometry import Polygon
from rasterio.features import rasterize, geometry_mask
from rasterio.mask import mask
import geopandas as gpd
import matplotlib.pyplot as plt

if 'UCProjectGroup1' in os.getcwd():
    os.chdir('..')
path = os.getcwd()  # current path
data_path =  os.path.join(os.getcwd(), "data") 
datadir = os.path.join(path, "data","images","images")
seglabeldir = os.path.join(path, "data","segmentation_labels","segmentation_labels")

# create folder to host the new segmentation maps
os.makedirs(f'{path}/data/labels', exist_ok=True)
default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)
seglabels = []
segfile_lookup = {}

# Make a lookup table for each segmentation file
idx = 0
for dirpath,dirnames,files in os.walk(seglabeldir):
    for seglabelfile in files:
        if not os.path.join(dirpath,seglabelfile).endswith(".json"):
            continue
        segdata = json.load(open(os.path.join(dirpath,
                                                seglabelfile), 'r'))
        seglabels.append(segdata)
        seg_key = "-".join(segdata['data']['image'].split('-')[1:]).replace(
                '.png', '.tif')
        seg_key = seg_key.replace(":","_")  # to match with image files
        segfile_lookup[seg_key] = idx
        # segfile_lookup[
        #     "-".join(segdata['data']['image'].split('-')[1:]).replace(
        #         '.png', '.tif')] = idx  # original code
        idx+=1

print(len(segfile_lookup.keys()))

# seglabels_poly = []

# Walking through all images and adding segmentation masks
for root, _, files in os.walk(datadir):
    for filename in files:
        if not filename.endswith('.tif'):
            continue
        # if filename not in segfile_lookup.keys():
        #     continue
        img_path = os.path.join(root, filename)

        # extracting image size
        if "120x120" in img_path:
            size = 120
        elif "300x300" in img_path:
            size = 300
        elif "validation" in img_path:
            size = 120
        else: 
            print("Outlier image size")  # detecting remaining outliers
            print(img_path)
            size = 120

        polygons = []
        if filename in segfile_lookup.keys():  # add polygons if clouds are present in image
            for completions in seglabels[segfile_lookup[filename]]['completions']:
                for result in completions['result']:
                    polygons.append(
                        np.array(
                            result['value']['points'] + [result['value']['points'][0]]) * size / 100)
        with rio.open(img_path, 'r') as src:  # reading image 
            img_file = src.read()

        # rasterize segmentation polygons
        fptdata = np.zeros((img_file.shape[1], img_file.shape[2]), dtype=np.uint8)
        # polygons = seglabels_poly.copy()
        shapes = []
        if len(polygons) > 0: # add polygons if clouds are present in images
            for pol in polygons:
                try:
                    pol = Polygon(pol)
                    shapes.append(pol)
                except ValueError:
                    continue
            polygon_geom = [(g, 1) for g in shapes]
            fptdata = rasterize(shapes,
                                out_shape=fptdata.shape,
                                all_touched=True)
        if len(shapes) <= 0 and (("120x120" in img_path) or ("validation" in img_path)):  # no available segmnetation map 
            # print("No segmentation map available")
            continue
        
        # convert raster to tiff 
        mask_name = f"data/labels/{filename}"
        with rio.open(mask_name, 
                    'w',
                    driver='GTiff',
                    width=size,
                    height=size,
                    dtype=fptdata.dtype,
                    transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                    count=1) as dst:
            # dst.nodata = 1  # does this help?
            # dst.write_mask(True)  # does this help?
            dst.write(fptdata, 1)  # writing
        
        # if positive, plot corresponding segmentation map
        # if size==120 and "positive" in img_path:
        #     print(fptdata[fptdata != 0])
        #     print(shapes)
        #     print(filename)
            # with rio.open(mask_name, 'r') as src:
            #     mask_file = src.read()
            #     fig, ax = plt.subplots()
            #     plt.imshow(mask_file[0, :, :], cmap='grey')
            #     plt.show()
print(len(os.listdir(os.path.join(os.getcwd(), "data", "labels"))))
print(idx)

dict_keys(['198_2019-02-15T10_06_39.586Z_2.tif', '198_2019-02-18T10_16_37.070Z_3.tif', '198_2019-02-25T10_07_37.455Z_4.tif', '198_2019-02-28T10_16_51.765Z_5.tif', '198_2019-03-20T10_29_41.074Z_6.tif', '198_2019-04-01T10_08_28.455Z_8.tif', '198_2019-04-04T10_20_02.455Z_9.tif', '198_2019-04-06T10_12_22.697Z_10.tif', '198_2019-04-16T10_12_09.368Z_11.tif', '198_2019-04-19T10_25_56.393Z_12.tif', '198_2019-04-21T10_20_36.492Z_13.tif', '198_2019-04-24T10_29_32.298Z_14.tif', '198_2019-05-24T10_16_40.727Z_16.tif', '198_2019-05-19T10_16_46.113Z_15.tif', '198_2019-06-05T10_06_47.801Z_18.tif', '198_2019-06-15T10_06_47.171Z_19.tif', '198_2019-06-30T10_06_44.463Z_22.tif', '198_2019-06-25T10_06_48.055Z_21.tif', '198_2019-07-23T10_16_42.508Z_23.tif', '198_2019-07-25T10_06_48.183Z_24.tif', '198_2019-08-22T10_16_39.689Z_25.tif', '198_2019-08-27T10_16_41.778Z_26.tif', '198_2019-09-01T10_16_37.949Z_27.tif', '198_2019-09-03T10_06_42.656Z_28.tif', '198_2019-09-11T10_16_35.705Z_29.tif', '198_2019-09-21T10_16

In [3]:
# Creating train-validation-test split for classification model
data_path =  os.path.join(os.getcwd(), "data") 
labels = pd.read_csv(os.path.join(data_path,'labels.csv')) 

label_filenames = labels['filename'].values
new_label_filenames = [l.replace(":","_") for l in label_filenames]
labels['filename'] = new_label_filenames

train, val, test= create_split(labels)
data_split ={"training":train, "validation":val, "test":test}
print(data_split["training"].shape, data_split["validation"].shape, data_split["test"].shape)

(1703,) (214,) (214,)


In [4]:
default_transform = rio.transform.from_bounds(0, 0, 120, 120, width=120, height=120)


path = os.getcwd()  # current path

datadir = os.path.join(path, "data","images","images")
reg_file = os.path.join(path, "data","labels.csv")
seglabeldir = os.path.join(path, "data","labels") # directory where segmentation masks are stored as tifs with the same filename as the image it pertains to
# examples = pd.read_csv(dataset_dir / f'{split}.csv')
examples = pd.read_csv(reg_file)  # extracting data from .csv
labels = examples['filename'].values
labels = [l.replace(":","_") for l in labels]  # reformatting filenames in labels.csv
examples['filename'] = labels
print(labels)

for split in data_split.keys():  # 'validation'

    # Create class dirs
    class_names = examples['fuel_type'].unique()  # defining classes 
    for class_name in class_names:
        os.makedirs(f'{path}/data/{split}/{class_name}', exist_ok=True)  # create new directories for train-validation sets

    # if split == 'training':
    #     # file_root = os.path.join(datadir, split, '120x120')  # only focus on the 120x120 in training for now
    #     file_root = os.path.join(datadir, split)  # use all images
    # else: 
    #     file_root = os.path.join(datadir, split)  # only 120x120 images in validation    


CROP_FLAG = True #This controls whether we crop all images to 120x120 or rescale them to 224x224 with bilinear resampling.

for dirpath,dirnames,files in os.walk(datadir):
    for file in files:
        if not file.endswith('.tif'):  # dont need to add file if not a tiff
            continue
            
        filepath = os.path.join(dirpath,file)
        seg_filepath = os.path.join(seglabeldir,file) #segmentation labels are stored in seglabeldir with the same filename

        if file in data_split["training"]:
            split = "training"
        elif file in data_split["validation"]:
            split = "validation"
        elif file in data_split["test"]:
            split = "test"
        else: 
            print("File is not in the list" )
        # print(file)
        
        if not CROP_FLAG:

            with rio.open(filepath) as src:  # open file of 224x224 to write to new directory
                load_file = src.read(
                    out_shape=(
                    src.count,
                    224,
                    224
                    ),
                resampling=Resampling.bilinear
                )
        
        else:
            #use crop functionality to load in the data. Works for both 120x120 and 300x300 and returns a 120x120 image. Also pass along the seg_filepath 
            #so that we know if we should interpolate spatial resolution or not
            load_file, _ = crop_image_and_segmentation(filepath, seg_filepath) # also returns cropped segmentation mask but we do not need it here so dump it in _

        # Stack bands (to match Prithvi channels)
        # BLUE, GREEN, RED, NIR_NARROW, SWIR_1, SWIR_2 / Landsat: B02, B03, B04, B05, B06, B07
        # Note that you don't have to match the channels, you can also define them in the config.
        # stacked = np.concatenate([rgb[[2, 1, 0]], ir.transpose((2, 0, 1))], axis=0)
        # stacked = np.concatenate([band for band in ])

        # Save images in dedicated folder per class
        # out_file = dataset_dir / split / row['label'] / f'{split}_{i}.tif'
        if np.isin(file,labels):  # checking if file is present in .csv
            file_csv = examples[examples['filename']==os.path.basename(os.path.normpath(file))]
            file_label = file_csv['fuel_type'].values[0]  # extract label
            file_index = file_csv.index[0]  # extract .csv index
            out_file = f'data/{split}/{file_label}/{split}_{file_index}.tif'  # new filepath

            if not CROP_FLAG:
                with rio.open(out_file, 
                            'w',
                            driver='GTiff',
                            width=224,  # changed from 120
                            height=224,  # changed from 120
                            dtype=load_file.dtype,
                            transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                            count=13) as dst:
                    dst.write(load_file)  # writing
            else:
                with rio.open(out_file, 
                            'w',
                            driver='GTiff',
                            width=120,  
                            height=120, 
                            dtype=load_file.dtype,
                            transform=default_transform,  # Adding wrong geotransform to avoid NotGeoreferencedWarning
                            count=13) as dst:
                    dst.write(load_file)  # writing



['0000__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0046__S2B-MSIL2A-ST20200113T104630-N0213-R008-T32ULC-20200113T112959.tif', '0002__S2B-MSIL2A-ST20200122T111720-N0213-R137-T30UYV-20200122T122946.tif', '0000__S2A-MSIL2A-ST20200104T110726-N0213-R094-T30UYV-20200104T122020.tif', '0046__S2B-MSIL2A-ST20200113T104631-N0213-R008-T31UGT-20200113T112959.tif', '0057__S2A-MSIL2A-ST20200217T104629-N0214-R008-T32ULC-20200217T121511.tif', '0057__S2B-MSIL2A-ST20200212T104630-N0214-R008-T32ULC-20200213T134833.tif', '0000__S2A-MSIL2A-ST20200206T111719-N0214-R137-T30UYV-20200206T122704.tif', '0046__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0046__S2A-MSIL2A-ST20200217T104630-N0214-R008-T31UGT-20200217T121511.tif', '0057__S2A-MSIL2A-ST20200207T104628-N0214-R008-T31UGT-20200207T122428.tif', '0042__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0057__S2A-MSIL2A-ST20200207T104627-N0214-R008-T32ULC-20200207T122428.tif', '0046__S2A-

In [5]:
# count number of files in each folder after splitting
for split in data_split.keys():
    split_path = os.path.join(data_path, split)
    file_count = sum(len(files) for _, _, files in os.walk(split_path))
    print(f"{split}: {file_count} files")

training: 1703 files
validation: 214 files
test: 214 files


### Part 1: Defining datamodule for lightning trainer

In [14]:

# # Timm requires 224x224 input.
# train_transforms = albumentations.Compose([
#     albumentations.RandomCrop(height=224, width=224),
#     albumentations.HorizontalFlip(),
#     albumentations.pytorch.transforms.ToTensorV2(),
# ])
# val_transforms = albumentations.Compose([
#     albumentations.CenterCrop(height=224, width=224),
#     albumentations.pytorch.transforms.ToTensorV2(),
# ])

# means of full 13 bands
means_full=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585, 1512.0585
    ]  # updated from dataset_multitask file, full
# means of 6 Prithvi bands
means=[1110.9012, 1250.0942, 1259.5178, 2251.6265, 1512.0585, 1512.0585]

# stds of full 13 bands
stds_full=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052, 1189.9052
    ]  # updated from dataset_multitask file 
# stds of 6 Prithvi bands
stds=[1418.4988, 1381.5366, 1406.7112, 1604.1998, 1189.9052, 1189.9052]

datamodule = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=27,
    train_data_root=os.path.join(path, 'data', 'training'),
    val_data_root=os.path.join(path, 'data', 'validation'),
    test_data_root=os.path.join(path, 'data', 'validation'), 
    means=means,
    stds=stds,
    num_classes=6,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    # output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    output_bands=('B02', 'B03', 'B04','B08','B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216) -> do we have HLS data?
    no_data_replace=0,  # replace each with 0, with is the mean after normalisation 
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

### Part 2: Defining classification error

In [34]:
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands
path_weights = os.path.join(os.getcwd(), "Prithvi_EO_V1_100M.pt")

VIT_UPERNET_NECK = [
    {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
    {"name": "ReshapeTokensToImage"},
    {"name": "LearnedInterpolateToPyramidal"},
]


model_args = {
        "in_channels": 13,
        "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
        "decoder": "UperNetDecoder",
        # "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),  # full
        "bands": ('B02', 'B03', 'B04','B08','B012', 'B013'),
        "backbone_pretrained_cfg_overlay":{"file": path_weights},
        # "pretrained":False,
        "pretrained":True,
        "num_classes": 6,
        "necks":  VIT_UPERNET_NECK
}

task = ClassificationTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    # pretrained_cfg=dict(file="Prithvi_EO_V1_100M.pt"),
    loss="ce",  # cross-entropy loss for 
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.01},
    freeze_backbone=False,
)

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'


### Part 4: Training the model - Initialising and fitting lightning trainer

In [35]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.0, patience=10, verbose=True)  # negative improvement counts as worsening
logger = TensorBoardLogger(save_dir='output', name='classification')

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=10, # train for 10 epochs for tuning
    default_root_dir='output/classification',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)
_ = trainer.fit(model=task, datamodule=datamodule)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

INFO: Metric val/loss improved. New best score: 1.050
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.050
INFO: Metric val/loss improved by 0.191 >= min_delta = 0.0. New best score: 0.859
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.191 >= min_delta = 0.0. New best score: 0.859
INFO: Metric val/loss improved by 0.261 >= min_delta = 0.0. New best score: 0.598
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.261 >= min_delta = 0.0. New best score: 0.598
INFO: Metric val/loss improved by 0.136 >= min_delta = 0.0. New best score: 0.462
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.136 >= min_delta = 0.0. New best score: 0.462
INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


### Part 5: Testing the finetuned model

In [36]:
res = trainer.test(model=task, datamodule=datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

### Part 6 - Conduct 5 runs of 50 epochs to obtain average and standard deviation

In [38]:
iterations = 5
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands
path_weights = os.path.join(os.getcwd(), "Prithvi_EO_V1_100M.pt")
results = {}

for it in range(iterations):

    # Re-defining model for clean run
    VIT_UPERNET_NECK = [
        {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ]

    model_args = {
            "in_channels": 13,
            "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
            "decoder": "UperNetDecoder",
            # "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),  # full
            "bands": ('B02', 'B03', 'B04','B08','B012', 'B013'),
            "backbone_pretrained_cfg_overlay":{"file": path_weights},
            # "pretrained":False,
            "pretrained":True,
            "num_classes": 6,
            "necks":  VIT_UPERNET_NECK
    }
    
    task = ClassificationTask(
        model_args=model_args,
        model_factory="PrithviModelFactory",
        loss="ce",  # cross-entropy loss for multiclass classification
        lr=1e-4,
        optimizer="AdamW",
        optimizer_hparams={"weight_decay": 0.01},
        freeze_backbone=False,
    )

     # Running model
    checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
    early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.0, patience=20, verbose=True)  # negative improvement counts as worsening
    logger = TensorBoardLogger(save_dir='output', name='classification')
    
    trainer = Trainer(
        devices=1, # Number of GPUs. Interactive mode recommended with 1 device
        precision="16-mixed",
        callbacks=[
            RichProgressBar(),
            checkpoint_callback,
            early_stopping_callback,
            LearningRateMonitor(logging_interval="epoch"),
        ],
        logger=logger,
        max_epochs=50, # train for 20 epochs for tuning
        default_root_dir='output/classification',
        log_every_n_steps=1,
        check_val_every_n_epoch=1
    )
    _ = trainer.fit(model=task, datamodule=datamodule)

    # Store results
    res = trainer.test(model=task, datamodule=datamodule)
    for param in res[0].keys():
        if it==0:  # add metric in first iteration
            results[param] = []
        results[param].append(res[0][param])
    print(results)

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.261
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.261
INFO: Metric val/loss improved by 0.279 >= min_delta = 0.0. New best score: 0.982
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.279 >= min_delta = 0.0. New best score: 0.982
INFO: Metric val/loss improved by 0.254 >= min_delta = 0.0. New best score: 0.728
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.254 >= min_delta = 0.0. New best score: 0.728
INFO: Metric val/loss improved by 0.123 >= min_delta = 0.0. New best score: 0.605
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.123 >= min_delta = 0.0. New best score: 0.605
INFO: Metric val/loss improved by 0.102 >= min_delta = 0.0. New best score: 0.504
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.102 >= min_delta = 0.0. New best score: 0.504
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.4581335484981537], 'test/Average_Accuracy': [0.7606539726257324], 'test/multiclassaccuracy_0': [0.9322034120559692], 'test/multiclassaccuracy_1': [1.0], 'test/multiclassaccuracy_2': [0.75], 'test/multiclassaccuracy_3': [0.8817204236984253], 'test/multiclassaccuracy_4': [0.0], 'test/multiclassaccuracy_5': [1.0], 'test/Multiclass_F1_Score': [0.8644859790802002], 'test/Multiclass_Jaccard_Index': [0.708715558052063], 'test/multiclassjaccardindex_0': [0.8088235259056091], 'test/multiclassjaccardindex_1': [1.0], 'test/multiclassjaccardindex_2': [0.6842105388641357], 'test/multiclassjaccardindex_3': [0.7592592835426331], 'test/multiclassjaccardindex_4': [0.0], 'test/multiclassjaccardindex_5': [1.0], 'test/Overall_Accuracy': [0.8644859790802002]}


INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.510
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.510
INFO: Metric val/loss improved by 0.143 >= min_delta = 0.0. New best score: 1.366
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.143 >= min_delta = 0.0. New best score: 1.366
INFO: Metric val/loss improved by 0.234 >= min_delta = 0.0. New best score: 1.133
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.234 >= min_delta = 0.0. New best score: 1.133
INFO: Metric val/loss improved by 0.096 >= min_delta = 0.0. New best score: 1.037
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.096 >= min_delta = 0.0. New best score: 1.037
INFO: Metric val/loss improved by 0.306 >= min_delta = 0.0. New best score: 0.731
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.306 >= min_delta = 0.0. New best score: 0.731
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.4581335484981537, 0.5892640948295593], 'test/Average_Accuracy': [0.7606539726257324, 0.7481553554534912], 'test/multiclassaccuracy_0': [0.9322034120559692, 0.7627118825912476], 'test/multiclassaccuracy_1': [1.0, 1.0], 'test/multiclassaccuracy_2': [0.75, 0.7692307829856873], 'test/multiclassaccuracy_3': [0.8817204236984253, 0.9569892287254333], 'test/multiclassaccuracy_4': [0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0], 'test/Multiclass_F1_Score': [0.8644859790802002, 0.855140209197998], 'test/Multiclass_Jaccard_Index': [0.708715558052063, 0.702407717704773], 'test/multiclassjaccardindex_0': [0.8088235259056091, 0.725806474685669], 'test/multiclassjaccardindex_1': [1.0, 1.0], 'test/multiclassjaccardindex_2': [0.6842105388641357, 0.7407407164573669], 'test/multiclassjaccardindex_3': [0.7592592835426331, 0.7478991746902466], 'test/multiclassjaccardindex_4': [0.0, 0.0], 'test/multiclassjaccardindex_5': [1.0, 1.0], 'test/Overall_Accuracy': [0.8644859790802002, 0.855140

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.443
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.443
INFO: Metric val/loss improved by 0.169 >= min_delta = 0.0. New best score: 1.274
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.169 >= min_delta = 0.0. New best score: 1.274
INFO: Metric val/loss improved by 0.461 >= min_delta = 0.0. New best score: 0.814
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.461 >= min_delta = 0.0. New best score: 0.814
INFO: Metric val/loss improved by 0.216 >= min_delta = 0.0. New best score: 0.597
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.216 >= min_delta = 0.0. New best score: 0.597
INFO: Metric val/loss improved by 0.113 >= min_delta = 0.0. New best score: 0.485
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.113 >= min_delta = 0.0. New best score: 0.485
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.4581335484981537, 0.5892640948295593, 0.7729626893997192], 'test/Average_Accuracy': [0.7606539726257324, 0.7481553554534912, 0.7164901494979858], 'test/multiclassaccuracy_0': [0.9322034120559692, 0.7627118825912476, 0.6610169410705566], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.75, 0.7692307829856873, 0.7884615659713745], 'test/multiclassaccuracy_3': [0.8817204236984253, 0.9569892287254333, 0.8494623899459839], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8644859790802002, 0.855140209197998, 0.7850467562675476], 'test/Multiclass_Jaccard_Index': [0.708715558052063, 0.702407717704773, 0.6469053030014038], 'test/multiclassjaccardindex_0': [0.8088235259056091, 0.725806474685669, 0.6000000238418579], 'test/multiclassjaccardindex_1': [1.0, 1.0, 1.0], 'test/multiclassjaccardindex_2': [0.6842105388641357, 0.7407407164573669, 0.611940324306488], 'test/multiclassjacc

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.093
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.093
INFO: Metric val/loss improved by 0.293 >= min_delta = 0.0. New best score: 0.800
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.293 >= min_delta = 0.0. New best score: 0.800
INFO: Metric val/loss improved by 0.303 >= min_delta = 0.0. New best score: 0.497
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.303 >= min_delta = 0.0. New best score: 0.497
INFO: Metric val/loss improved by 0.025 >= min_delta = 0.0. New best score: 0.472
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.025 >= min_delta = 0.0. New best score: 0.472
INFO: Metric val/loss improved by 0.046 >= min_delta = 0.0. New best score: 0.427
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.046 >= min_delta = 0.0. New best score: 0.427
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.4581335484981537, 0.5892640948295593, 0.7729626893997192, 0.7073894143104553], 'test/Average_Accuracy': [0.7606539726257324, 0.7481553554534912, 0.7164901494979858, 0.7615798115730286], 'test/multiclassaccuracy_0': [0.9322034120559692, 0.7627118825912476, 0.6610169410705566, 1.0], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.75, 0.7692307829856873, 0.7884615659713745, 0.7307692170143127], 'test/multiclassaccuracy_3': [0.8817204236984253, 0.9569892287254333, 0.8494623899459839, 0.8387096524238586], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8644859790802002, 0.855140209197998, 0.7850467562675476, 0.8598130941390991], 'test/Multiclass_Jaccard_Index': [0.708715558052063, 0.702407717704773, 0.6469053030014038, 0.7041366696357727], 'test/multiclassjaccardindex_0': [0.8088235259056091, 0.725806474685669, 0.6000000238418579, 0.7564102411270142], 'tes

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.112
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.112
INFO: Metric val/loss improved by 0.343 >= min_delta = 0.0. New best score: 0.769
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.343 >= min_delta = 0.0. New best score: 0.769
INFO: Metric val/loss improved by 0.245 >= min_delta = 0.0. New best score: 0.524
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.245 >= min_delta = 0.0. New best score: 0.524
INFO: Metric val/loss improved by 0.008 >= min_delta = 0.0. New best score: 0.515
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.008 >= min_delta = 0.0. New best score: 0.515
INFO: Metric val/loss improved by 0.096 >= min_delta = 0.0. New best score: 0.420
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.096 >= min_delta = 0.0. New best score: 0.420
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.4581335484981537, 0.5892640948295593, 0.7729626893997192, 0.7073894143104553, 0.5371072292327881], 'test/Average_Accuracy': [0.7606539726257324, 0.7481553554534912, 0.7164901494979858, 0.7615798115730286, 0.7657616138458252], 'test/multiclassaccuracy_0': [0.9322034120559692, 0.7627118825912476, 0.6610169410705566, 1.0, 0.8305084705352783], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.75, 0.7692307829856873, 0.7884615659713745, 0.7307692170143127, 0.9038461446762085], 'test/multiclassaccuracy_3': [0.8817204236984253, 0.9569892287254333, 0.8494623899459839, 0.8387096524238586, 0.8602150678634644], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8644859790802002, 0.855140209197998, 0.7850467562675476, 0.8598130941390991, 0.8644859790802002], 'test/Multiclass_Jaccard_Index': [0.708715558052063, 0.702407717704773, 0.6469053030014038, 0.70

In [39]:
# print total results    
print(results)
for param in results.keys():
    print(f'Average {param}= {np.mean(results[param])}+={np.std(results[param])}')

{'test/loss': [0.4581335484981537, 0.5892640948295593, 0.7729626893997192, 0.7073894143104553, 0.5371072292327881], 'test/Average_Accuracy': [0.7606539726257324, 0.7481553554534912, 0.7164901494979858, 0.7615798115730286, 0.7657616138458252], 'test/multiclassaccuracy_0': [0.9322034120559692, 0.7627118825912476, 0.6610169410705566, 1.0, 0.8305084705352783], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.75, 0.7692307829856873, 0.7884615659713745, 0.7307692170143127, 0.9038461446762085], 'test/multiclassaccuracy_3': [0.8817204236984253, 0.9569892287254333, 0.8494623899459839, 0.8387096524238586, 0.8602150678634644], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8644859790802002, 0.855140209197998, 0.7850467562675476, 0.8598130941390991, 0.8644859790802002], 'test/Multiclass_Jaccard_Index': [0.708715558052063, 0.702407717704773, 0.6469053030014038, 0.70

In [40]:
iterations = 5
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands
path_weights = os.path.join(os.getcwd(), "Prithvi_EO_V1_100M.pt")
results = {}

# means of full 13 bands
means_full=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585, 1512.0585
    ]  # updated from dataset_multitask file, full
# means of 6 Prithvi bands
means=[1110.9012, 1250.0942, 1259.5178, 2251.6265, 1512.0585, 1512.0585]

# stds of full 13 bands
stds_full=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052, 1189.9052
    ]  # updated from dataset_multitask file 
# stds of 6 Prithvi bands
stds=[1418.4988, 1381.5366, 1406.7112, 1604.1998, 1189.9052, 1189.9052]

datamodule = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=27,
    train_data_root=os.path.join(path, 'data', 'training'),
    val_data_root=os.path.join(path, 'data', 'validation'),
    test_data_root=os.path.join(path, 'data', 'test'), 
    means=means_full,
    stds=stds_full,
    num_classes=6,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # output_bands=('B02', 'B03', 'B04','B08','B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216) -> do we have HLS data?
    no_data_replace=0,  # replace each with 0, with is the mean after normalisation 
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

# Training model on full 13-bands
for it in range(iterations):

    # Re-defining model for clean run
    VIT_UPERNET_NECK = [
        {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ]

    model_args = {
            "in_channels": 13,
            "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
            "decoder": "UperNetDecoder",
            "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),  # full
            # "bands": ('B02', 'B03', 'B04','B08','B012', 'B013'),
            "backbone_pretrained_cfg_overlay":{"file": path_weights},
            # "pretrained":False,
            "pretrained":True,
            "num_classes": 6,
            "necks":  VIT_UPERNET_NECK
    }
    
    task = ClassificationTask(
        model_args=model_args,
        model_factory="PrithviModelFactory",
        loss="ce",  # cross-entropy loss for multiclass classification
        lr=1e-4,
        optimizer="AdamW",
        optimizer_hparams={"weight_decay": 0.01},
        freeze_backbone=False,
    )

     # Running model
    checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
    early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.0, patience=20, verbose=True)  # negative improvement counts as worsening
    logger = TensorBoardLogger(save_dir='output', name='classification')
    
    trainer = Trainer(
        devices=1, # Number of GPUs. Interactive mode recommended with 1 device
        precision="16-mixed",
        callbacks=[
            RichProgressBar(),
            checkpoint_callback,
            early_stopping_callback,
            LearningRateMonitor(logging_interval="epoch"),
        ],
        logger=logger,
        max_epochs=50, # train for 20 epochs for tuning
        default_root_dir='output/classification',
        log_every_n_steps=1,
        check_val_every_n_epoch=1
    )
    _ = trainer.fit(model=task, datamodule=datamodule)

    # Store results
    res = trainer.test(model=task, datamodule=datamodule)
    for param in res[0].keys():
        if it==0:  # add metric in first iteration
            results[param] = []
        results[param].append(res[0][param])
    print(results)

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.067
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.067
INFO: Metric val/loss improved by 0.459 >= min_delta = 0.0. New best score: 0.607
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.459 >= min_delta = 0.0. New best score: 0.607
INFO: Metric val/loss improved by 0.020 >= min_delta = 0.0. New best score: 0.588
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.020 >= min_delta = 0.0. New best score: 0.588
INFO: Metric val/loss improved by 0.008 >= min_delta = 0.0. New best score: 0.580
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.008 >= min_delta = 0.0. New best score: 0.580
INFO: Metric val/loss improved by 0.104 >= min_delta = 0.0. New best score: 0.476
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.104 >= min_delta = 0.0. New best score: 0.476
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.3391992449760437], 'test/Average_Accuracy': [0.7864599227905273], 'test/multiclassaccuracy_0': [0.8983050584793091], 'test/multiclassaccuracy_1': [1.0], 'test/multiclassaccuracy_2': [0.8999999761581421], 'test/multiclassaccuracy_3': [0.9204545617103577], 'test/multiclassaccuracy_4': [0.0], 'test/multiclassaccuracy_5': [1.0], 'test/Multiclass_F1_Score': [0.9065420627593994], 'test/Multiclass_Jaccard_Index': [0.7480905055999756], 'test/multiclassjaccardindex_0': [0.7910447716712952], 'test/multiclassjaccardindex_1': [1.0], 'test/multiclassjaccardindex_2': [0.8709677457809448], 'test/multiclassjaccardindex_3': [0.8265306353569031], 'test/multiclassjaccardindex_4': [0.0], 'test/multiclassjaccardindex_5': [1.0], 'test/Overall_Accuracy': [0.9065420627593994]}


INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.201
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.201
INFO: Metric val/loss improved by 0.590 >= min_delta = 0.0. New best score: 0.611
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.590 >= min_delta = 0.0. New best score: 0.611
INFO: Metric val/loss improved by 0.227 >= min_delta = 0.0. New best score: 0.384
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.227 >= min_delta = 0.0. New best score: 0.384
INFO: Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.384
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.384
INFO: Metric val/loss improved by 0.130 >= min_delta = 0.0. New best score: 0.254
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.130 >= min_delta = 0.0. New best score: 0.254
INFO: `Trainer.fit` stoppe

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.3391992449760437, 0.37173789739608765], 'test/Average_Accuracy': [0.7864599227905273, 0.7764060497283936], 'test/multiclassaccuracy_0': [0.8983050584793091, 0.9152542352676392], 'test/multiclassaccuracy_1': [1.0, 1.0], 'test/multiclassaccuracy_2': [0.8999999761581421, 0.800000011920929], 'test/multiclassaccuracy_3': [0.9204545617103577, 0.9431818127632141], 'test/multiclassaccuracy_4': [0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0], 'test/Multiclass_F1_Score': [0.9065420627593994, 0.8925233483314514], 'test/Multiclass_Jaccard_Index': [0.7480905055999756, 0.7323979139328003], 'test/multiclassjaccardindex_0': [0.7910447716712952, 0.7605633735656738], 'test/multiclassjaccardindex_1': [1.0, 1.0], 'test/multiclassjaccardindex_2': [0.8709677457809448, 0.7868852615356445], 'test/multiclassjaccardindex_3': [0.8265306353569031, 0.8469387888908386], 'test/multiclassjaccardindex_4': [0.0, 0.0], 'test/multiclassjaccardindex_5': [1.0, 1.0], 'test/Overall_Accuracy': [0.90654206

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.243
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.243
INFO: Metric val/loss improved by 0.069 >= min_delta = 0.0. New best score: 1.174
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.069 >= min_delta = 0.0. New best score: 1.174
INFO: Metric val/loss improved by 0.263 >= min_delta = 0.0. New best score: 0.912
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.263 >= min_delta = 0.0. New best score: 0.912
INFO: Metric val/loss improved by 0.053 >= min_delta = 0.0. New best score: 0.859
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.053 >= min_delta = 0.0. New best score: 0.859
INFO: Metric val/loss improved by 0.403 >= min_delta = 0.0. New best score: 0.456
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.403 >= min_delta = 0.0. New best score: 0.456
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.3391992449760437, 0.37173789739608765, 0.4083777666091919], 'test/Average_Accuracy': [0.7864599227905273, 0.7764060497283936, 0.7768939137458801], 'test/multiclassaccuracy_0': [0.8983050584793091, 0.9152542352676392, 1.0], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.8999999761581421, 0.800000011920929, 0.8999999761581421], 'test/multiclassaccuracy_3': [0.9204545617103577, 0.9431818127632141, 0.7613636255264282], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.9065420627593994, 0.8925233483314514, 0.8691588640213013], 'test/Multiclass_Jaccard_Index': [0.7480905055999756, 0.7323979139328003, 0.7187793850898743], 'test/multiclassjaccardindex_0': [0.7910447716712952, 0.7605633735656738, 0.7662337422370911], 'test/multiclassjaccardindex_1': [1.0, 1.0, 1.0], 'test/multiclassjaccardindex_2': [0.8709677457809448, 0.7868852615356445, 0.8181818127632141], 'test/multiclass

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.061
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.061
INFO: Metric val/loss improved by 0.134 >= min_delta = 0.0. New best score: 0.928
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.134 >= min_delta = 0.0. New best score: 0.928
INFO: Metric val/loss improved by 0.301 >= min_delta = 0.0. New best score: 0.626
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.301 >= min_delta = 0.0. New best score: 0.626
INFO: Metric val/loss improved by 0.068 >= min_delta = 0.0. New best score: 0.558
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.068 >= min_delta = 0.0. New best score: 0.558
INFO: Metric val/loss improved by 0.116 >= min_delta = 0.0. New best score: 0.443
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.116 >= min_delta = 0.0. New best score: 0.443
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.3391992449760437, 0.37173789739608765, 0.4083777666091919, 0.49227336049079895], 'test/Average_Accuracy': [0.7864599227905273, 0.7764060497283936, 0.7768939137458801, 0.7696028351783752], 'test/multiclassaccuracy_0': [0.8983050584793091, 0.9152542352676392, 1.0, 0.9661017060279846], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.8999999761581421, 0.800000011920929, 0.8999999761581421, 0.8333333134651184], 'test/multiclassaccuracy_3': [0.9204545617103577, 0.9431818127632141, 0.7613636255264282, 0.8181818127632141], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.9065420627593994, 0.8925233483314514, 0.8691588640213013, 0.8644859790802002], 'test/Multiclass_Jaccard_Index': [0.7480905055999756, 0.7323979139328003, 0.7187793850898743, 0.7123656272888184], 'test/multiclassjaccardindex_0': [0.7910447716712952, 0.7605633735656738, 0.7662337422370911, 0.7307

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.241
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.241
INFO: Metric val/loss improved by 0.300 >= min_delta = 0.0. New best score: 0.942
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.300 >= min_delta = 0.0. New best score: 0.942
INFO: Metric val/loss improved by 0.015 >= min_delta = 0.0. New best score: 0.926
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.015 >= min_delta = 0.0. New best score: 0.926
INFO: Metric val/loss improved by 0.211 >= min_delta = 0.0. New best score: 0.715
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.211 >= min_delta = 0.0. New best score: 0.715
INFO: Metric val/loss improved by 0.044 >= min_delta = 0.0. New best score: 0.671
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.044 >= min_delta = 0.0. New best score: 0.671
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.3391992449760437, 0.37173789739608765, 0.4083777666091919, 0.49227336049079895, 0.5115923285484314], 'test/Average_Accuracy': [0.7864599227905273, 0.7764060497283936, 0.7768939137458801, 0.7696028351783752, 0.771513819694519], 'test/multiclassaccuracy_0': [0.8983050584793091, 0.9152542352676392, 1.0, 0.9661017060279846, 0.8813559412956238], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.8999999761581421, 0.800000011920929, 0.8999999761581421, 0.8333333134651184, 0.8500000238418579], 'test/multiclassaccuracy_3': [0.9204545617103577, 0.9431818127632141, 0.7613636255264282, 0.8181818127632141, 0.8977272510528564], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.9065420627593994, 0.8925233483314514, 0.8691588640213013, 0.8644859790802002, 0.8785046935081482], 'test/Multiclass_Jaccard_Index': [0.7480905055999756, 0.7323979139328003, 0.71877

In [41]:
# print total results of 13 bands  
print(results)
for param in results.keys():
    print(f'Average {param}= {np.mean(results[param])}+={np.std(results[param])}')

{'test/loss': [0.3391992449760437, 0.37173789739608765, 0.4083777666091919, 0.49227336049079895, 0.5115923285484314], 'test/Average_Accuracy': [0.7864599227905273, 0.7764060497283936, 0.7768939137458801, 0.7696028351783752, 0.771513819694519], 'test/multiclassaccuracy_0': [0.8983050584793091, 0.9152542352676392, 1.0, 0.9661017060279846, 0.8813559412956238], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.8999999761581421, 0.800000011920929, 0.8999999761581421, 0.8333333134651184, 0.8500000238418579], 'test/multiclassaccuracy_3': [0.9204545617103577, 0.9431818127632141, 0.7613636255264282, 0.8181818127632141, 0.8977272510528564], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.9065420627593994, 0.8925233483314514, 0.8691588640213013, 0.8644859790802002, 0.8785046935081482], 'test/Multiclass_Jaccard_Index': [0.7480905055999756, 0.7323979139328003, 0.71877

In [42]:
iterations = 5
pretrained_bands = prithvi_vit.PRETRAINED_BANDS  # need to still select the correct bands
path_weights = os.path.join(os.getcwd(), "Prithvi_EO_V1_100M.pt")
results = {}

# means of full 13 bands
means_full=[
    960.97437, 1110.9012, 1250.0942, 1259.5178, 1500.98,
    1989.6344, 2155.846, 2251.6265, 2272.9438, 2442.6206,
    1914.3, 1512.0585, 1512.0585
    ]  # updated from dataset_multitask file, full
# means of 6 Prithvi bands
means=[1110.9012, 1250.0942, 1259.5178, 2251.6265, 1512.0585, 1512.0585]

# stds of full 13 bands
stds_full=[
    1302.0157, 1418.4988, 1381.5366, 1406.7112, 1387.4155,
    1438.8479, 1497.8815, 1604.1998, 1516.532, 1827.3025, 
    1303.83, 1189.9052, 1189.9052
    ]  # updated from dataset_multitask file 
# stds of 6 Prithvi bands
stds=[1418.4988, 1381.5366, 1406.7112, 1604.1998, 1189.9052, 1189.9052]

datamodule = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=27,
    train_data_root=os.path.join(path, 'data', 'training'),
    val_data_root=os.path.join(path, 'data', 'validation'),
    test_data_root=os.path.join(path, 'data', 'test'), 
    means=means,
    stds=stds,
    num_classes=6,

    # if transforms are defined with Albumentations, you can pass them here
    # train_transform=train_transforms,
    # val_transform=val_transforms,
    # test_transform=val_transforms,

    # Bands of your dataset (in this case similar to the model bands)
    dataset_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    # Input bands of your model
    # output_bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),
    output_bands=('B02', 'B03', 'B04','B08','B012', 'B013'),
    constant_scale=39.216,  # Scale 0-255 data to 0-10000 (HLS data) (10000 / 255 = 39.216) -> do we have HLS data?
    no_data_replace=0,  # replace each with 0, with is the mean after normalisation 
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

# Training model on full 13-bands
for it in range(iterations):

    # Re-defining model for clean run
    VIT_UPERNET_NECK = [
        {"name": "SelectIndices", "indices": [1, 2, 3, 4]},
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ]

    model_args = {
            "in_channels": 6,
            "backbone": "prithvi_vit_100", # see timm.list_pretrained() 
            "decoder": "UperNetDecoder",
            # "bands": ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B010', 'B011', 'B012', 'B013'),  # full
            "bands": ('B02', 'B03', 'B04','B08','B012', 'B013'),
            "backbone_pretrained_cfg_overlay":{"file": path_weights},
            # "pretrained":False,
            "pretrained":True,
            "num_classes": 6,
            "necks":  VIT_UPERNET_NECK
    }
    
    task = ClassificationTask(
        model_args=model_args,
        model_factory="PrithviModelFactory",
        loss="ce",  # cross-entropy loss for multiclass classification
        lr=1e-4,
        optimizer="AdamW",
        optimizer_hparams={"weight_decay": 0.01},
        freeze_backbone=False,
    )

     # Running model
    checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
    early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.0, patience=20, verbose=True)  # negative improvement counts as worsening
    logger = TensorBoardLogger(save_dir='output', name='classification')
    
    trainer = Trainer(
        devices=1, # Number of GPUs. Interactive mode recommended with 1 device
        precision="16-mixed",
        callbacks=[
            RichProgressBar(),
            checkpoint_callback,
            early_stopping_callback,
            LearningRateMonitor(logging_interval="epoch"),
        ],
        logger=logger,
        max_epochs=50, # train for 20 epochs for tuning
        default_root_dir='output/classification',
        log_every_n_steps=1,
        check_val_every_n_epoch=1
    )
    _ = trainer.fit(model=task, datamodule=datamodule)

    # Store results
    res = trainer.test(model=task, datamodule=datamodule)
    for param in res[0].keys():
        if it==0:  # add metric in first iteration
            results[param] = []
        results[param].append(res[0][param])
    print(results)

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 2.613
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 2.613
INFO: Metric val/loss improved by 0.868 >= min_delta = 0.0. New best score: 1.745
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.868 >= min_delta = 0.0. New best score: 1.745
INFO: Metric val/loss improved by 0.574 >= min_delta = 0.0. New best score: 1.171
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.574 >= min_delta = 0.0. New best score: 1.171
INFO: Metric val/loss improved by 0.419 >= min_delta = 0.0. New best score: 0.752
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.419 >= min_delta = 0.0. New best score: 0.752
INFO: Metric val/loss improved by 0.086 >= min_delta = 0.0. New best score: 0.666
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.086 >= min_delta = 0.0. New best score: 0.666
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.45407089591026306], 'test/Average_Accuracy': [0.7788050174713135], 'test/multiclassaccuracy_0': [0.9152542352676392], 'test/multiclassaccuracy_1': [1.0], 'test/multiclassaccuracy_2': [0.9166666865348816], 'test/multiclassaccuracy_3': [0.8409090638160706], 'test/multiclassaccuracy_4': [0.0], 'test/multiclassaccuracy_5': [1.0], 'test/Multiclass_F1_Score': [0.8831775784492493], 'test/Multiclass_Jaccard_Index': [0.6442587375640869], 'test/multiclassjaccardindex_0': [0.7605633735656738], 'test/multiclassjaccardindex_1': [0.800000011920929], 'test/multiclassjaccardindex_2': [0.859375], 'test/multiclassjaccardindex_3': [0.7789473533630371], 'test/multiclassjaccardindex_4': [0.0], 'test/multiclassjaccardindex_5': [0.6666666865348816], 'test/Overall_Accuracy': [0.8831775784492493]}


INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.046
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.046
INFO: Metric val/loss improved by 0.029 >= min_delta = 0.0. New best score: 1.017
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.029 >= min_delta = 0.0. New best score: 1.017
INFO: Metric val/loss improved by 0.373 >= min_delta = 0.0. New best score: 0.644
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.373 >= min_delta = 0.0. New best score: 0.644
INFO: Metric val/loss improved by 0.035 >= min_delta = 0.0. New best score: 0.609
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.035 >= min_delta = 0.0. New best score: 0.609
INFO: Metric val/loss improved by 0.019 >= min_delta = 0.0. New best score: 0.590
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.019 >= min_delta = 0.0. New best score: 0.590
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.45407089591026306, 0.5348568558692932], 'test/Average_Accuracy': [0.7788050174713135, 0.7543442845344543], 'test/multiclassaccuracy_0': [0.9152542352676392, 0.7457627058029175], 'test/multiclassaccuracy_1': [1.0, 1.0], 'test/multiclassaccuracy_2': [0.9166666865348816, 0.9166666865348816], 'test/multiclassaccuracy_3': [0.8409090638160706, 0.8636363744735718], 'test/multiclassaccuracy_4': [0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0], 'test/Multiclass_F1_Score': [0.8831775784492493, 0.8457943797111511], 'test/Multiclass_Jaccard_Index': [0.6442587375640869, 0.6431959271430969], 'test/multiclassjaccardindex_0': [0.7605633735656738, 0.7213114500045776], 'test/multiclassjaccardindex_1': [0.800000011920929, 1.0], 'test/multiclassjaccardindex_2': [0.859375, 0.7333333492279053], 'test/multiclassjaccardindex_3': [0.7789473533630371, 0.737864077091217], 'test/multiclassjaccardindex_4': [0.0, 0.0], 'test/multiclassjaccardindex_5': [0.6666666865348816, 0.6666666865348816], 't

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.081
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.081
INFO: Metric val/loss improved by 0.422 >= min_delta = 0.0. New best score: 0.659
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.422 >= min_delta = 0.0. New best score: 0.659
INFO: Metric val/loss improved by 0.070 >= min_delta = 0.0. New best score: 0.589
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.070 >= min_delta = 0.0. New best score: 0.589
INFO: Metric val/loss improved by 0.225 >= min_delta = 0.0. New best score: 0.364
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.225 >= min_delta = 0.0. New best score: 0.364
INFO: Metric val/loss improved by 0.030 >= min_delta = 0.0. New best score: 0.334
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.030 >= min_delta = 0.0. New best score: 0.334
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.45407089591026306, 0.5348568558692932, 0.31282293796539307], 'test/Average_Accuracy': [0.7788050174713135, 0.7543442845344543, 0.7796567678451538], 'test/multiclassaccuracy_0': [0.9152542352676392, 0.7457627058029175, 0.9491525292396545], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.9166666865348816, 0.9166666865348816, 0.9333333373069763], 'test/multiclassaccuracy_3': [0.8409090638160706, 0.8636363744735718, 0.7954545617103577], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8831775784492493, 0.8457943797111511, 0.8785046935081482], 'test/Multiclass_Jaccard_Index': [0.6442587375640869, 0.6431959271430969, 0.726899266242981], 'test/multiclassjaccardindex_0': [0.7605633735656738, 0.7213114500045776, 0.8358209133148193], 'test/multiclassjaccardindex_1': [0.800000011920929, 1.0, 1.0], 'test/multiclassjaccardindex_2': [0.859375, 0.7333333492279053, 0.788732409477233

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.516
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.516
INFO: Metric val/loss improved by 0.481 >= min_delta = 0.0. New best score: 1.036
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.481 >= min_delta = 0.0. New best score: 1.036
INFO: Metric val/loss improved by 0.166 >= min_delta = 0.0. New best score: 0.869
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.166 >= min_delta = 0.0. New best score: 0.869
INFO: Metric val/loss improved by 0.319 >= min_delta = 0.0. New best score: 0.550
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.319 >= min_delta = 0.0. New best score: 0.550
INFO: Metric val/loss improved by 0.063 >= min_delta = 0.0. New best score: 0.488
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.063 >= min_delta = 0.0. New best score: 0.488
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.45407089591026306, 0.5348568558692932, 0.31282293796539307, 0.3503466844558716], 'test/Average_Accuracy': [0.7788050174713135, 0.7543442845344543, 0.7796567678451538, 0.777133584022522], 'test/multiclassaccuracy_0': [0.9152542352676392, 0.7457627058029175, 0.9491525292396545, 0.8135592937469482], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.9166666865348816, 0.9166666865348816, 0.9333333373069763, 0.8833333253860474], 'test/multiclassaccuracy_3': [0.8409090638160706, 0.8636363744735718, 0.7954545617103577, 0.9659090638160706], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8831775784492493, 0.8457943797111511, 0.8785046935081482, 0.8971962332725525], 'test/Multiclass_Jaccard_Index': [0.6442587375640869, 0.6431959271430969, 0.726899266242981, 0.7364679574966431], 'test/multiclassjaccardindex_0': [0.7605633735656738, 0.7213114500045776, 0.8358209133

INFO:timm.models._builder:Loading pretrained weights from file (/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt)
  checkpoint = torch.load(checkpoint_path, map_location=device)
INFO:timm.models._helpers:Loaded  from checkpoint '/vol/home/s2267063/UC Project/Prithvi_EO_V1_100M.pt'
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/vol/home/s2267063/.conda/envs/terratorch/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU a

Output()

INFO: Metric val/loss improved. New best score: 1.218
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved. New best score: 1.218
INFO: Metric val/loss improved by 0.168 >= min_delta = 0.0. New best score: 1.050
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.168 >= min_delta = 0.0. New best score: 1.050
INFO: Metric val/loss improved by 0.023 >= min_delta = 0.0. New best score: 1.027
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.023 >= min_delta = 0.0. New best score: 1.027
INFO: Metric val/loss improved by 0.408 >= min_delta = 0.0. New best score: 0.620
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.408 >= min_delta = 0.0. New best score: 0.620
INFO: Metric val/loss improved by 0.009 >= min_delta = 0.0. New best score: 0.610
INFO:lightning.pytorch.callbacks.early_stopping:Metric val/loss improved by 0.009 >= min_delta = 0.0. New best score: 0.610
INFO: Metric val/loss impr

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

{'test/loss': [0.45407089591026306, 0.5348568558692932, 0.31282293796539307, 0.3503466844558716, 0.2745228409767151], 'test/Average_Accuracy': [0.7788050174713135, 0.7543442845344543, 0.7796567678451538, 0.777133584022522, 0.7790104150772095], 'test/multiclassaccuracy_0': [0.9152542352676392, 0.7457627058029175, 0.9491525292396545, 0.8135592937469482, 0.8983050584793091], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.9166666865348816, 0.9166666865348816, 0.9333333373069763, 0.8833333253860474, 0.8666666746139526], 'test/multiclassaccuracy_3': [0.8409090638160706, 0.8636363744735718, 0.7954545617103577, 0.9659090638160706, 0.9090909361839294], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8831775784492493, 0.8457943797111511, 0.8785046935081482, 0.8971962332725525, 0.8925233483314514], 'test/Multiclass_Jaccard_Index': [0.6442587375640869, 0.643195927

In [43]:
# print total results of 6 bands  
print(results)
for param in results.keys():
    print(f'Average {param}= {np.mean(results[param])}+={np.std(results[param])}')

{'test/loss': [0.45407089591026306, 0.5348568558692932, 0.31282293796539307, 0.3503466844558716, 0.2745228409767151], 'test/Average_Accuracy': [0.7788050174713135, 0.7543442845344543, 0.7796567678451538, 0.777133584022522, 0.7790104150772095], 'test/multiclassaccuracy_0': [0.9152542352676392, 0.7457627058029175, 0.9491525292396545, 0.8135592937469482, 0.8983050584793091], 'test/multiclassaccuracy_1': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/multiclassaccuracy_2': [0.9166666865348816, 0.9166666865348816, 0.9333333373069763, 0.8833333253860474, 0.8666666746139526], 'test/multiclassaccuracy_3': [0.8409090638160706, 0.8636363744735718, 0.7954545617103577, 0.9659090638160706, 0.9090909361839294], 'test/multiclassaccuracy_4': [0.0, 0.0, 0.0, 0.0, 0.0], 'test/multiclassaccuracy_5': [1.0, 1.0, 1.0, 1.0, 1.0], 'test/Multiclass_F1_Score': [0.8831775784492493, 0.8457943797111511, 0.8785046935081482, 0.8971962332725525, 0.8925233483314514], 'test/Multiclass_Jaccard_Index': [0.6442587375640869, 0.643195927

### Processing the segmentation masks into the correct folders