In [1]:
from typing import Dict
from torch import nn

from remote_sensing_ddpm.downstream_tasks.modules.feature_extractor import FeatureExtractor
from remote_sensing_ddpm.downstream_tasks.downstream_task_model import DownstreamTaskModel
from remote_sensing_ddpm.train_downstream_tasks import fuse_backbone_and_downstream_head_config

In [2]:
from pprint import pprint
from typing import Any, Dict
from inspect import isfunction

from importlib import import_module
from lit_diffusion.constants import (
    PYTHON_CLASS_CONFIG_KEY,
    PYTHON_ARGS_CONFIG_KEY,
    PYTHON_KWARGS_CONFIG_KEY,
    INSTANTIATE_DELAY_CONFIG_KEY,
    CALL_UPON_INSTANTIATION_KEY,
)

_POSSIBLE_ARGS_CONFIG_KEYS = [
    PYTHON_ARGS_CONFIG_KEY,
    PYTHON_KWARGS_CONFIG_KEY,
    INSTANTIATE_DELAY_CONFIG_KEY,
    CALL_UPON_INSTANTIATION_KEY,
]

def instantiate_python_class_from_string_config(
    class_config: Dict,
    verbose: bool = False,
    **kwargs,
):
    # Assert that necessary keys are contained in config
    assert isinstance(class_config, Dict), f"{class_config} is not a dictionary."
    assert (
        PYTHON_CLASS_CONFIG_KEY in class_config.keys()
    ), f"Expected key {PYTHON_CLASS_CONFIG_KEY} but got keys: {', '.join(class_config.keys())}"

    def recursive_call_with_check(possible_config_dict: Any):
        # If a parameters is a dictionary...
        if isinstance(possible_config_dict, Dict):
            keys = set(possible_config_dict.keys())
            # ... delay instantiation to a later call if desired ...
            if INSTANTIATE_DELAY_CONFIG_KEY in keys:
                if possible_config_dict[INSTANTIATE_DELAY_CONFIG_KEY] > 0:
                    possible_config_dict[INSTANTIATE_DELAY_CONFIG_KEY] -= 1
                    return possible_config_dict
            # ... check if it is a valid instantiation config ...
            valid_config_key_sets = [{PYTHON_CLASS_CONFIG_KEY}]
            for idx in range(len(_POSSIBLE_ARGS_CONFIG_KEYS)):
                for jdx in reversed(
                    range(idx + 1, len(_POSSIBLE_ARGS_CONFIG_KEYS) + 1)
                ):
                    valid_config_key_sets.append(
                        {PYTHON_CLASS_CONFIG_KEY, *_POSSIBLE_ARGS_CONFIG_KEYS[idx:jdx]}
                    )
            if any(keys == subset for subset in valid_config_key_sets):
                # ... and if so instantiate the python object.
                return instantiate_python_class_from_string_config(
                    class_config=possible_config_dict,
                    verbose=verbose,
                )

            # Check all levels of dict
            for config_keys, config_values in possible_config_dict.items():
                possible_config_dict[config_keys] = recursive_call_with_check(
                    config_values
                )
        # If parameters is a list
        elif isinstance(possible_config_dict, list):
            # check all entries of list
            for idx, config_values in enumerate(possible_config_dict):
                possible_config_dict[idx] = recursive_call_with_check(config_values)
        return possible_config_dict

    # Recursively instantiate any further required python objects
    # ...for regular arguments
    class_args = class_config.get(PYTHON_ARGS_CONFIG_KEY, list())
    for i, v in enumerate(class_args):
        class_args[i] = recursive_call_with_check(v)
    # ...for key-word arguments
    class_kwargs = class_config.get(PYTHON_KWARGS_CONFIG_KEY, dict())
    for k, v in class_kwargs.items():
        class_kwargs[k] = recursive_call_with_check(v)

    # Get module and class names
    module_full_name: str = class_config[PYTHON_CLASS_CONFIG_KEY]
    module_sub_names = module_full_name.split(".")
    module_name = ".".join(module_sub_names[:-1])
    class_name = module_sub_names[-1]
    # Import specified module
    module = import_module(module_name)
    object_to_instantiate = getattr(module, class_name)

    # Python function call of the module attribute with specified config values
    if verbose:
        print(f"Instantiating {class_name} with the following arguments:")
        pprint(
            class_config
        )
    # Give user the option to call it upon instantiation
    if class_config.get(CALL_UPON_INSTANTIATION_KEY, True):
        object_to_instantiate = object_to_instantiate(
            *class_args,
            **class_kwargs,
            **kwargs,
        )
    return object_to_instantiate

In [3]:
import yaml

with open("../config/model_configs/downstream_tasks/feature_extractors/s2_s1.yaml", "r") as bcf:
    backbone_config = yaml.safe_load(bcf)
with open("../config/model_configs/downstream_tasks/tier_1/ewc-segmentation.yaml", "r") as dcf:
    downstream_config = yaml.safe_load(dcf)
    
complete_config = fuse_backbone_and_downstream_head_config(
    backbone_config=backbone_config,
    downstream_head_config=downstream_config,
)

In [5]:
model = instantiate_python_class_from_string_config(
    complete_config["pl_module"],
)

Instantiating CrossEntropyLoss with the following arguments:
{'class_args': [], 'class_kwargs': {}, 'kwargs': {}}
Instantiating Accuracy with the following arguments:
{'class_args': [],
 'class_kwargs': {'num_classes': 11, 'task': 'multiclass'},
 'kwargs': {}}
Instantiating TorchmetricsAdapter with the following arguments:
{'class_args': [],
 'class_kwargs': {'apply_argmax': True,
                  'device': 'cpu',
                  'torchmetrics_module': MulticlassAccuracy()},
 'kwargs': {}}
Instantiating JaccardIndex with the following arguments:
{'class_args': [],
 'class_kwargs': {'num_classes': 11, 'task': 'multiclass'},
 'kwargs': {}}
Instantiating TorchmetricsAdapter with the following arguments:
{'class_args': [],
 'class_kwargs': {'apply_argmax': True,
                  'device': 'cpu',
                  'torchmetrics_module': MulticlassJaccardIndex()},
 'kwargs': {}}
Instantiating Accuracy with the following arguments:
{'class_args': [],
 'class_kwargs': {'average': None, 'nu

In [7]:
import wandb

from pathlib import Path
from remote_sensing_ddpm.train_downstream_tasks import create_wandb_run_name, LABEL_FRACTION_PATHS

FEATURE_EXTRACTOR_FILES = Path("../config/model_configs/downstream_tasks/feature_extractors")
EWC_SEGMENTATION_CONFIG_PATH = Path("../../config/model_configs/downstream_tasks/tier_1/ewc-segmentation.yaml")

FEATURE_EXTRACTOR_NAMES = list(FEATURE_EXTRACTOR_FILES.glob("*.yaml"))
MODALITY_NAMES = [p.name.split(".")[0] for p in FEATURE_EXTRACTOR_NAMES]
EXPERIMENT_NAMES = [
    create_wandb_run_name(backbone_name=backbone_path.name, downstream_head_name=EWC_SEGMENTATION_CONFIG_PATH.name)
    for backbone_path in FEATURE_EXTRACTOR_NAMES
]
LABEL_FRACTION_EXPERIMENT_NAMES = [
    name + f"-lf-{fraction}"
    for name in EXPERIMENT_NAMES
    for fraction in LABEL_FRACTION_PATHS.keys()
]

In [13]:
api = wandb.Api()
RUN_FILTER = {"$and": [{"display_name": {"$in": EXPERIMENT_NAMES}}, {"state": {"$eq": "finished"}}]}
WANDB_PROJECT_NAME = "rs-ddpm-ms-segmentation-egypt"
runs = api.runs(f"ssl-diffusion/{WANDB_PROJECT_NAME}", filters=RUN_FILTER)

In [None]:
for backbone_path in FEATURE_EXTRACTOR

In [29]:
checkpoint_paths = {}
for run in runs:
    run_path = Path(f"../{WANDB_PROJECT_NAME}/{run.id}/checkpoints/last.ckpt")
    if run.name in checkpoint_paths.keys():
        checkpoint_paths[run.name].append(run_path)
    else:
        checkpoint_paths[run.name] = [run_path]
checkpoint_paths

{'s2_era5-ewc-segmentation': [PosixPath('../rs-ddpm-ms-segmentation-egypt/izp0czwp/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/k36ol1zh/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/hbubxaq8/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/luca7w95/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/klyfmsqg/checkpoints/last.ckpt')],
 's2_climate_zones-ewc-segmentation': [PosixPath('../rs-ddpm-ms-segmentation-egypt/m0kxo0rv/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/19nu3sor/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/owgmlaed/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/y90jylya/checkpoints/last.ckpt'),
  PosixPath('../rs-ddpm-ms-segmentation-egypt/1rtftjvy/checkpoints/last.ckpt')],
 's2_glo_30_dem-ewc-segmentation': [PosixPath('../rs-ddpm-ms-segmentation-egypt/yoyj3tbd/checkpoints/last.ckpt'),
  PosixPath('../rs-

In [30]:
import torch

In [61]:
smpl = torch.load("../rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/last.ckpt")

In [62]:
smpl["callbacks"]["ModelCheckpoint{'monitor': 'val/mean_squared_error', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 5, 'train_time_interval': None}"]

{'monitor': 'val/mean_squared_error',
 'best_model_score': tensor(0.0026, device='cuda:0'),
 'best_model_path': './rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/epoch=14-step=58230.ckpt',
 'current_score': tensor(0.0026, device='cuda:0'),
 'dirpath': './rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints',
 'best_k_models': {'./rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/epoch=9-step=38820.ckpt': tensor(0.0027, device='cuda:0'),
  './rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/epoch=14-step=58230.ckpt': tensor(0.0026, device='cuda:0'),
  './rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/epoch=19-step=77640.ckpt': tensor(0.0026, device='cuda:0')},
 'kth_best_model_path': './rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/epoch=9-step=38820.ckpt',
 'kth_value': tensor(0.0027, device='cuda:0'),
 'last_model_path': './rs-ddpm-ms-regression-egypt/9kca3uap/checkpoints/last.ckpt'}

## Get best run for certain epochs ##

In [52]:
import wandb
from tqdm import tqdm
import pandas as pd

api = wandb.Api()

In [46]:
run = api.run("ssl-diffusion/rs-ddpm-ms-regression-egypt/9kca3uap")

In [64]:
keys_of_interest = ["epoch", "val/mean_squared_error"]
single_run_complete_history = []
# Get all data
for x in tqdm(run.scan_history(keys=keys_of_interest, page_size=10000), desc="Loading history"):
    single_run_complete_history.append(x)

Loading history: 20it [00:05,  3.38it/s]


In [74]:
getattr(history_df["val/mean_squared_error"], f"idxmin")()

14

In [71]:
history_df["val/mean_squared_error"].idxmin()

14

In [78]:
history_df = pd.DataFrame(single_run_complete_history)

available_epochs = [4, 9, 14, 19]
history_df = history_df.loc[available_epochs, :]
history_df["val/mean_squared_error"].idxmin()

14

In [77]:
history_df.loc[history_df["val/mean_squared_error"].idxmin(), :]["epoch"]

14.0

In [68]:
history_df

Unnamed: 0,epoch,val/mean_squared_error
0,0,0.003741
1,1,0.003247
2,2,0.003115
3,3,0.003046
4,4,0.00292
5,5,0.002872
6,6,0.002771
7,7,0.002766
8,8,0.0027
9,9,0.002698


In [82]:
import os

os.listdir("../rs-ddpm-ms-segmentation-egypt/h2vuc5rd/checkpoints")

['epoch=4-step=190.ckpt',
 'last.ckpt',
 'epoch=9-step=380.ckpt',
 'epoch=14-step=570.ckpt']