In [None]:
import albumentations as A
import numpy as np
import timm
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data.dataloader import default_collate
from torchvision import transforms as tt
from torchvision.utils import draw_segmentation_masks


from ccb.experiment.experiment import Job, get_model_generator
from ccb.torch_toolbox.dataset import DataModule
from ccb.dataset_converters.inspect_tools import float_image_to_uint8, overlay_label

from ruamel.yaml import YAML
import yaml
from pathlib import Path
import pickle
import tempfile
import random
import matplotlib.pyplot as plt
import pandas as pd
from ccb import io
from ccb.io.dataset import Band
from matplotlib import cm

## Specify model and dataset

In [None]:
config_file_path = "/mnt/home/climate-change-benchmark/ccb/configs/classification_config.yaml"
task_specs_path = "/mnt/data/cc_benchmark/classification_v0.7/eurosat/task_specs.pkl"
# NOTE(review): absolute cluster paths — adjust them to your environment before running.

# The task family is inferred from the benchmark directory name in the task-specs path.
if "classification" in task_specs_path:
    task = "classification"
else:
    task = "segmentation"

config = yaml.safe_load(Path(config_file_path).read_text())

# Override the model section of the loaded config.
config["model"].update(
    model_generator_module_name="ccb.torch_toolbox.model_generators.timm_generator",
    backbone="resnet18",
    encoder_type="resnet18",
    decoder_type="Unet",
    desired_input_size=224,
    batch_size=16,
)
# Benchmark root = the directory that contains the dataset folder holding task_specs.pkl.
config["experiment"]["benchmark_dir"] = str(Path(task_specs_path).parents[1])

# pickle.load can execute arbitrary code: only load task specs from trusted locations.
with Path(task_specs_path).open("rb") as fd:
    task_specs = pickle.load(fd)

num_samples_to_viz = 16


## Create temporary job directory

In [None]:
# Create a throw-away experiment directory and store the config and task specs in it.
# TemporaryDirectory removes the directory on cleanup() or when the object is finalized.
temp_dir = tempfile.TemporaryDirectory()
temp_dir_path = temp_dir.name

job_dir = Path(temp_dir_path, task_specs.dataset_name)
job = Job(job_dir)
job.save_config(config)
job.save_task_specs(task_specs)

## Load DataModule and Model

In [None]:
model_gen = get_model_generator(config["model"]["model_generator_module_name"])
model = model_gen.generate_model(task_specs=job.task_specs, config=config)


def _load_split(split):
    """Return one split of the dataset, using the partition/bands/format settings from ``config``."""
    return task_specs.get_dataset(
        split=split,
        partition_name=config["experiment"]["partition_name"],
        band_names=config["dataset"]["band_names"],
        format=config["dataset"]["format"],
        benchmark_dir=Path(task_specs_path).parents[1],
    )


train_ds = _load_split("train")
eval_ds = _load_split("valid")

## Define a transform function that you would like to test

#### Classification Transform

In [None]:
def check_transform_function_classification(task_specs, config, train=True):
    """Build a debugging version of the classification transform.

    The returned callable packs a sample's bands into a float32 array and applies
    ``ToTensor -> Normalize -> [RandomHorizontalFlip] -> Resize``. It prints the shape and
    per-channel mean/std before and after, so the normalization can be checked by eye.

    Args:
        task_specs: task specification; its train split provides the normalization stats.
        config: experiment config. Reads ``dataset.format``, ``dataset.band_names``,
            ``experiment.benchmark_dir``, ``experiment.partition_name`` and
            ``model.default_input_size``.
        train: if True, also apply a random horizontal flip.

    Returns:
        A function mapping a sample to ``{"input": tensor, "label": sample.label}``.
    """
    band_names = config["dataset"]["band_names"]
    mean, std = task_specs.get_dataset(
        split="train",
        format=config["dataset"]["format"],
        band_names=tuple(band_names),
        benchmark_dir=config["experiment"]["benchmark_dir"],
        partition_name=config["experiment"]["partition_name"],
    ).normalization_stats()

    # NOTE(review): this reads "default_input_size", while the config cell sets
    # "desired_input_size"; presumably the model generator fills this key — confirm.
    desired_input_size = config["model"]["default_input_size"][1]

    t = [tt.ToTensor(), tt.Normalize(mean=mean, std=std)]
    if train:
        t.append(tt.RandomHorizontalFlip())
    t.append(tt.Resize((desired_input_size, desired_input_size)))
    transform_comp = tt.Compose(t)

    def _channel_stats(channels_first):
        """Return ``(mean, std)`` for every channel of a C x H x W numpy array (ddof=0 std)."""
        return [(float(ch.mean()), float(ch.std())) for ch in channels_first]

    def transform(sample):
        x: "np.typing.NDArray[np.float32]" = sample.pack_to_3d(band_names=band_names)[0].astype("float32")
        print("before")
        print(x.shape)
        # x is H x W x C here (ToTensor turns it into C x H x W); report every band.
        print(*_channel_stats(np.moveaxis(x, -1, 0)))

        x = transform_comp(x)

        print("after")
        print(x.shape)
        # Go through numpy so "before" and "after" std use the same formula.
        print(*_channel_stats(x.numpy()))

        # Resize always yields a square image of exactly desired_input_size.
        assert tuple(x.shape[-2:]) == (desired_input_size, desired_input_size), x.shape
        return {"input": x, "label": sample.label}

    return transform

#### Segmentation Transform Function

In [None]:
def check_transform_function_segmentation(task_specs, config, train=True):
    """Build a debugging version of the segmentation transform.

    The returned callable packs a sample's bands into a float32 image and applies the
    following albumentations pipeline jointly to the image and its mask:
    ``[SmallestMaxSize] -> RandomCrop(multiple of 32) -> [RandomRotate90, Flip] -> Normalize
    -> ToTensorV2``. It prints the shape and per-channel mean/std before and after.

    Args:
        task_specs: task specification; provides ``patch_size`` and the train split whose
            normalization stats are used.
        config: experiment config. Reads ``model.input_size``, ``dataset.format``,
            ``dataset.band_names``, ``experiment.benchmark_dir`` and
            ``experiment.partition_name``.
        train: if True, also apply random 90-degree rotations and flips.

    Returns:
        A function mapping a sample to ``{"input": tensor, "label": long tensor mask}``.

    Raises:
        RuntimeError: if the model input size or the dataset patches are not square.
    """
    _, h, w = config["model"]["input_size"]
    patch_h, patch_w = task_specs.patch_size
    if h != w or patch_h != patch_w:
        raise RuntimeError("Only square patches are supported in this version")
    h32 = w32 = int(32 * (h // 32))  # make input res multiple of 32

    band_names = config["dataset"]["band_names"]
    mean, std = task_specs.get_dataset(
        split="train",
        format=config["dataset"]["format"],
        band_names=tuple(band_names),
        benchmark_dir=config["experiment"]["benchmark_dir"],
        partition_name=config["experiment"]["partition_name"],
    ).normalization_stats()

    t = []
    if h < patch_h:
        t.append(A.SmallestMaxSize(max_size=h))
    # NOTE(review): when patch_h < h32 the crop is larger than the patch and RandomCrop
    # raises; confirm this never happens for the datasets inspected here.
    t.append(A.RandomCrop(h32, w32))
    if train:
        t.append(A.RandomRotate90(0.5))
        t.append(A.Flip())
    # mean/std are in the raw data units, but albumentations multiplies them by
    # max_pixel_value (default 255). Pin it to 1.0 so this equals (x - mean) / std.
    t.append(A.Normalize(mean=mean, std=std, max_pixel_value=1.0))
    t.append(ToTensorV2())
    t_comp = A.Compose(t)

    def _channel_stats(channels_first):
        """Return ``(mean, std)`` for every channel of a C x H x W numpy array (ddof=0 std)."""
        return [(float(ch.mean()), float(ch.std())) for ch in channels_first]

    def transform(sample):
        # Without a Band label there is no mask to transform; fail loudly instead of
        # hitting an unbound variable further down.
        if not isinstance(sample.label, Band):
            raise TypeError(
                f"Segmentation samples need a Band label, got {type(sample.label).__name__}"
            )
        x = sample.pack_to_3d(band_names=band_names)[0].astype("float32")
        y = sample.label.data.astype("float32")
        print("before")
        print(x.shape)
        # x is channels-last here (the "after" image is channels-first); report every band.
        print(*_channel_stats(np.moveaxis(x, -1, 0)))

        transformed = t_comp(image=x, mask=y)
        a_img = transformed["image"]
        print("after")
        print(a_img.shape)
        print(*_channel_stats(a_img.numpy()))

        return {"input": a_img, "label": transformed["mask"].long()}

    return transform


#### Select transform function

In [None]:
# Pick the transform builder that matches the task inferred from task_specs_path.
if task in ("classification", "segmentation"):
    build_transform = (
        check_transform_function_classification
        if task == "classification"
        else check_transform_function_segmentation
    )
    check_train_transform = build_transform(task_specs, config, train=True)
    check_eval_transform = build_transform(task_specs, config, train=False)

### Get a batch from a Dataset

In [None]:
def collect_batch(ds, num_samples=num_samples_to_viz, transform=None):
    """Draw random samples from ``ds`` and return them before and after a transform.

    Args:
        ds: dataset indexable by int, yielding samples with ``pack_to_3d`` and ``label``.
        num_samples: number of samples to draw (independently, so repeats are possible).
        transform: callable applied to each sample; defaults to ``check_train_transform``.
            Pass ``check_eval_transform`` to inspect the evaluation pipeline instead.

    Returns:
        ``(before, after)``: lists of dicts. ``before`` holds ``{"input": float32 array from
        pack_to_3d, "label": sample.label}``; ``after`` holds whatever ``transform`` returns.
    """
    if transform is None:
        transform = check_train_transform
    band_names = config["dataset"]["band_names"]
    before_transform = []
    after_transform = []
    for _ in range(num_samples):
        # randrange excludes len(ds); randint(0, len(ds)) could return an out-of-range index.
        sample = ds[random.randrange(len(ds))]
        sample_array = sample.pack_to_3d(band_names=band_names)[0].astype("float32")
        before_transform.append({"input": sample_array, "label": sample.label})
        after_transform.append(transform(sample))

    return before_transform, after_transform

## Functions to visualize transformations

In [None]:
def visualize_classification(before, after):
    """Show each classification sample before (left) and after (right) the transform.

    Args:
        before: list of dicts whose ``"input"`` is the raw channels-last numpy image
            (as produced by ``collect_batch``).
        after: list of dicts whose ``"input"`` is a channels-first torch tensor.
    """
    if not before:
        return

    before_for_viz = float_image_to_uint8([b["input"] for b in before])
    after_for_viz = float_image_to_uint8([a["input"].permute(1, 2, 0).numpy() for a in after])

    # squeeze=False keeps axs 2-D even when only a single sample is plotted.
    fig, axs = plt.subplots(ncols=2, nrows=len(before), figsize=(20, 60), squeeze=False)
    for idx, (b_z, a_z) in enumerate(zip(before_for_viz, after_for_viz)):
        axs[idx, 0].imshow(b_z)
        axs[idx, 0].axis("off")

        # TODO convert classification numeric label into text label
        axs[idx, 1].imshow(a_z)
        axs[idx, 1].axis("off")

    axs[0, 0].set_title("Original")
    axs[0, 1].set_title("After transform")

    # No horizontal gap so the before/after images sit side by side.
    plt.subplots_adjust(wspace=0)
    plt.show()

def color_list(n_classes, background_id=0, background_color=(0, 0, 0)):
    """Return one RGB color per class, sampled evenly from matplotlib's HSV colormap.

    The colormap is sampled at ``n_classes + 1`` points and the last sample is discarded,
    because HSV wraps around and it would nearly repeat the first color. The row at
    ``background_id`` is then overwritten with ``background_color``.

    Returns:
        numpy array with ``n_classes`` rows and 3 (RGB) columns.
    """
    rgba = cm.hsv(np.linspace(0, 1, n_classes + 1))
    # keep the first n_classes samples and drop the alpha column in one slice
    palette = rgba[:-1, :3]
    palette[background_id] = background_color
    return palette

def visualize_segmentation(before, after):
    """Plot each segmentation sample before and after the transform, plus summary stats.

    Each row shows three panels:
    - the original image with its label overlaid;
    - a table of image size and min/max/mean/std, before vs. after;
    - the transformed image with the transformed mask overlaid.

    Args:
        before: list of ``{"input": channels-last numpy image, "label": sample label}`` dicts.
        after: list of ``{"input": channels-first tensor, "label": integer mask tensor}`` dicts.
    """
    if not before:
        return

    before_imgs = float_image_to_uint8([b["input"] for b in before])
    after_imgs = float_image_to_uint8([a["input"].permute(1, 2, 0).numpy() for a in after])
    after_labels = [a["label"].numpy() for a in after]

    n_classes = task_specs.label_type.n_classes
    # draw_segmentation_masks takes one 0-255 RGB tuple per mask; color_list yields 0-1 floats.
    mask_colors = [tuple(int(round(255 * c)) for c in rgb) for rgb in color_list(n_classes)]

    # squeeze=False keeps axs 2-D even when only a single sample is plotted.
    fig, axs = plt.subplots(ncols=3, nrows=len(before), figsize=(20, 60), squeeze=False)
    rows = zip(before, after, before_imgs, after_imgs, after_labels)
    for idx, (b, a, b_img_u8, a_img_u8, a_label) in enumerate(rows):
        # before: original image with its label overlaid
        overlay_before = overlay_label(b_img_u8, b["label"], label_patch_size=None, opacity=0.5).astype(int)
        axs[idx, 0].imshow(overlay_before)
        axs[idx, 0].axis("off")

        # table comparing pixel statistics before / after the transform
        b_img = b["input"]
        a_img = a["input"]
        data = {
            'image_size': [b_img.shape[0], a_img.shape[2]],
            'max_px': [b_img.max(), a_img.max().item()],
            'min_px': [b_img.min(), a_img.min().item()],
            'mean_px': [b_img.mean(), a_img.mean().item()],
            'std_px': [b_img.std(), a_img.std().item()]
        }
        df = pd.DataFrame.from_dict(data, orient='index', columns=["before", "after"])
        axs[idx, 1].table(cellText=df.values, colLabels=df.columns, rowLabels=df.index, loc='center')
        axs[idx, 1].axis("off")

        # after: transformed image with the transformed mask overlaid
        img = torch.from_numpy(a_img_u8).permute(2, 0, 1)
        one_hot_label = (
            torch.nn.functional.one_hot(torch.from_numpy(a_label), num_classes=n_classes)
            .permute(2, 0, 1)
            .bool()
        )
        overlay_after = draw_segmentation_masks(img, one_hot_label, alpha=0.5, colors=mask_colors)
        axs[idx, 2].imshow(overlay_after.permute(1, 2, 0).numpy())
        axs[idx, 2].axis("off")

    axs[0, 0].set_title("Before transform")
    axs[0, 2].set_title("After transform")
    plt.show()


def visualize_transforms(before, after, task):
    """Dispatch to the classification or segmentation visualizer for ``task``.

    Raises:
        AssertionError: if ``before`` and ``after`` have different lengths.
        ValueError: if ``task`` is neither "classification" nor "segmentation".
    """
    assert len(before) == len(after), "Each input needs transformed output."
    if task not in ("classification", "segmentation"):
        raise ValueError("Invalid task, use 'classification' or 'segmentation'")
    visualizer = visualize_classification if task == "classification" else visualize_segmentation
    visualizer(before, after)



In [None]:
# Sample a random training batch and compare every sample before / after the train transform.
train_before, train_after = collect_batch(train_ds, num_samples=num_samples_to_viz)
visualize_transforms(before=train_before, after=train_after, task=task)

In [None]:
# temp_dir.cleanup()