In [None]:
from geobench import io
import pandas as pd
from geobench.dataset_converters import inspect_tools
from geobench import io
from tqdm import tqdm
import numpy as np
from geobench.dataset_converters import inspect_tools
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = [10, 6]
from pathlib import Path

%load_ext autoreload
%autoreload 2

In [None]:

def collect_task_info(task):
    """Summarize one task as a flat dict and return it with its train dataset.

    Args:
        task: Task specification object. The attributes read here are
            ``eval_loss``, ``get_dataset``, ``label_type``, ``dataset_name``,
            ``patch_size``, ``n_time_steps`` and ``bands_info``.

    Returns:
        A tuple ``(task_dict, dataset)``.

        ``task_dict`` holds the keys ``'name'``, ``'img size'``, ``'loss'``,
        ``'label type'``, ``'n classes'``, ``'n time steps'``, ``'n train'``,
        ``'n valid'`` and ``'n test'``, plus whatever
        ``inspect_tools.summarize_band_info`` returns for the task's bands.
        The three partition sizes are ``-1`` when loading the dataset or
        reading its partition failed; ``'n classes'`` is ``-1`` when the label
        type has no ``n_classes`` attribute.

        ``dataset`` is the object returned by
        ``task.get_dataset(split='train')``, or ``None`` when that call raised.
    """
    loss = task.eval_loss
    # `eval_loss` may be given as a class rather than an instance; instantiate
    # it so that `str(loss)` describes an actual loss object.
    if isinstance(loss, type):
        loss = loss()

    # Bind `dataset` before the try block: if `get_dataset` itself raises, the
    # name would otherwise be unbound and the final `return` would fail with
    # UnboundLocalError, defeating the best-effort fallback below.
    dataset = None
    try:
        dataset = task.get_dataset(split='train')
        partition = dataset.active_partition.partition_dict
        n_train = len(partition["train"])
        n_valid = len(partition["valid"])
        n_test = len(partition["test"])
    except Exception as e:
        # Best effort: report the problem and flag the sizes as unknown so
        # that one broken dataset does not abort a whole benchmark summary.
        print(e)
        n_train, n_valid, n_test = -1, -1, -1

    n_classes = getattr(task.label_type, "n_classes", -1)

    task_dict = {
        'name': task.dataset_name,
        'img size': ' x '.join([str(size) for size in task.patch_size]),
        'loss': str(loss),
        'label type': task.label_type.__class__.__name__,
        'n classes': int(n_classes),
        'n time steps': task.n_time_steps,
        'n train': n_train,
        'n valid': n_valid,
        'n test': n_test,
    }
    task_dict.update(inspect_tools.summarize_band_info(task.bands_info))
    return task_dict, dataset

def collect_benchmark_info(benchmark_name):
    """Collect the summary dict of every task in one benchmark.

    Args:
        benchmark_name: Name of the benchmark directory, resolved relative to
            ``io.CCB_DIR``.

    Returns:
        A list with one dict per task, in the order yielded by
        ``io.task_iterator``. Each dict is the first element of what
        ``collect_task_info`` returns; the dataset it also returns is dropped.
    """
    benchmark_dir = io.CCB_DIR / benchmark_name
    records = []
    for task in io.task_iterator(benchmark_dir):
        # Echo the dataset name as a simple progress indicator.
        print(task.dataset_name)
        records.append(collect_task_info(task)[0])
    return records



In [None]:
def extract_classification_samples(dataset: "io.GeobenchDataset", num_samples=8, rng=np.random):
    """Draw up to ``num_samples`` samples spread evenly across the label classes.

    The label map returned by ``dataset.task_specs.get_label_map()`` is
    iterated class by class; for each class, ``ceil(num_samples / n_classes)``
    sample names are drawn without replacement and loaded with
    ``dataset.get_sample``. The concatenated list is then truncated to
    ``num_samples``.

    Args:
        dataset: Dataset exposing ``task_specs.get_label_map()`` and
            ``get_sample(name)``.
        num_samples: Upper bound on the number of samples returned.
        rng: Object providing ``choice(a, size=..., replace=...)``; defaults
            to the ``numpy.random`` module.

    Returns:
        A list of at most ``num_samples`` samples, grouped by class in the
        iteration order of the label map. It can be shorter than
        ``num_samples`` when some classes hold fewer names than the per-class
        quota, and it is empty when the label map is empty.
    """
    label_map = dataset.task_specs.get_label_map()
    n_classes = len(label_map)
    if n_classes == 0:
        # Nothing to draw from; also avoids dividing by zero below.
        return []

    n_per_class = int(np.ceil(num_samples / n_classes))
    samples = []
    for names in label_map.values():
        # Sampling without replacement cannot return more items than the class
        # holds, so cap the request instead of letting `choice` raise.
        n_draw = min(n_per_class, len(names))
        if n_draw <= 0:
            continue
        for sample_name in rng.choice(names, size=n_draw, replace=False):
            samples.append(dataset.get_sample(sample_name))
    return samples[:num_samples]

def plot_images(images, names):
    """Show images side by side in a single row, each titled with its name.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``; one subplot
            column is created per image.
        names: Titles, paired with ``images`` by position.
    """
    # squeeze=False makes `subplots` always return a 2-D array of Axes. With
    # the default squeezing, a single image yields a bare Axes object, which
    # cannot be iterated by the zip below.
    fig, axs = plt.subplots(1, len(images), squeeze=False)
    for image, name, ax in zip(images, names, axs[0]):
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(name)

def replace_str(name):
    """Abbreviate a few long label descriptions occurring in ``name``.

    Every occurrence of each known long phrase is replaced by its short form;
    text that matches none of the phrases is returned unchanged.

    Args:
        name: Label text to shorten.

    Returns:
        The text with all known long phrases substituted.
    """
    abbreviations = (
        ("Land principally occupied by agriculture, with significant areas of natural vegetation",
         "Ag. and vegetation"),
        ("Non-irrigated arable land", "Non-irrigated land"),
        ("Complex cultivation patterns", "Cultivation patterns"),
        ("Fruit trees and berry plantations", "Fruit trees and berry"),
    )
    shortened = name
    for long_form, short_form in abbreviations:
        shortened = shortened.replace(long_form, short_form)
    return shortened


# Number of example images plotted per task.
n_samples = 4

# Directory receiving one PDF per task. Resolved once, outside the loop, and
# named `out_dir` so the builtin `dir` is not shadowed. It is created up front
# so that `savefig` does not fail when the directory does not exist yet.
out_dir = Path("~/paper").expanduser()
out_dir.mkdir(parents=True, exist_ok=True)

for task in io.task_iterator(io.CCB_DIR / "classification_v0.5"):
    # collect_task_info loads the train split itself; keep the dataset it
    # returns instead of loading the same split a second time.
    task_info, dataset = collect_task_info(task)

    print(f"Task: {task.dataset_name}")
    print(f"sizes: train={task_info['n train']}.")
    print(f"RGB Shape: {task_info['img size']} ")

    if dataset is None:
        # No dataset was handed back; load the train split directly.
        dataset = task.get_dataset(split="train")

    if not isinstance(task.label_type, io.label.MultiLabelClassification):
        # Single-label task: draw examples spread across the classes.
        samples = extract_classification_samples(dataset, n_samples)
    else:
        # Multi-label task: simply take the first few samples.
        samples = [dataset[i] for i in range(n_samples)]

    images, labels = inspect_tools.extract_images(samples)

    label_names = [replace_str(task.label_type.value_to_str(label)) for label in labels]

    plot_images(images, label_names)
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    plt.savefig(out_dir / f"{task.dataset_name}.pdf", bbox_inches='tight')




In [None]:

# task_dicts = collect_benchmark_info("converted")
# column_order = ("name", "img size", "label type", 'n classes', 'n train', 'n valid', 'n test', 'n time steps', "Bands count", "Sentinel2 count", "RGB res", "NIR res", "HS res", "Elevation res")
# df = pd.DataFrame.from_records(task_dicts, index="name", columns=column_order)
# pd.set_option('max_colwidth', 300)

# df

In [None]:
# Summary table of the "segmentation_v0.2" benchmark: one row per task,
# indexed by task name, with the columns restricted to `column_order`.
column_order = (
    "name", "img size", "label type",
    'n classes', 'n train', 'n valid', 'n test', 'n time steps',
    "Bands count", "Sentinel2 count",
    "RGB res", "NIR res", "HS res", "Elevation res",
)
task_dicts = collect_benchmark_info("segmentation_v0.2")
df = pd.DataFrame.from_records(task_dicts, index="name", columns=column_order)
# Raise the column-width display limit so long cell text is shown in full.
pd.set_option('max_colwidth', 300)
df

In [None]:
# Summary table of the "classification_v0.5" benchmark: one row per task,
# indexed by task name, with the columns restricted to `column_order`.
column_order = (
    "name", "img size", "label type",
    'n classes', 'n train', 'n valid', 'n test', 'n time steps',
    "Bands count", "Sentinel2 count",
    "RGB res", "NIR res", "HS res", "Elevation res",
)
task_dicts = collect_benchmark_info("classification_v0.5")
df = pd.DataFrame.from_records(task_dicts, index="name", columns=column_order)
# Raise the column-width display limit so long cell text is shown in full.
pd.set_option('max_colwidth', 300)
df

In [None]:
# Summary table of the "converted" benchmark: one row per task, indexed by
# task name, with the columns restricted to `column_order`.
column_order = (
    "name", "img size", "label type",
    'n classes', 'n train', 'n valid', 'n test', 'n time steps',
    "Bands count", "Sentinel2 count",
    "RGB res", "NIR res", "HS res", "Elevation res",
)
task_dicts = collect_benchmark_info("converted")
df = pd.DataFrame.from_records(task_dicts, index="name", columns=column_order)
# Raise the column-width display limit so long cell text is shown in full.
pd.set_option('max_colwidth', 300)
df