In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Root folder of the public datasets; relative, so it resolves against the notebook's working directory.
path_to_dataset = "../../public_data"



In [None]:
import pathlib

import cdmetadl.helpers.general_helpers
import cdmetadl.dataset

dataset_path = pathlib.Path(path_to_dataset)

# check_datasets is given the "APL" (Airplanes) code and its result is indexed
# by that same code below; presumably it validates the folder and returns
# per-dataset metadata keyed by code — TODO confirm.
dataset_info_dict = cdmetadl.helpers.general_helpers.check_datasets(
    dataset_path,
    ["APL"],
)
dataset = cdmetadl.dataset.ImageDataset(
    "Airplanes",
    dataset_info_dict["APL"],
)

In [None]:
import cdmetadl.samplers

# Both samplers are built from a single fixed value, so every generated task
# is presumably 5-way / 4-shot.
n_way_sampler = cdmetadl.samplers.ValueSampler(value=5)
k_shot_sampler = cdmetadl.samplers.ValueSampler(value=4)
task = dataset.generate_task(
    n_way_sampler,
    k_shot_sampler,
    query_size=4,
)

In [None]:
import cdmetadl.augmentation
import cdmetadl.notebooks.helpers
import numpy as np
import torch
from PIL import Image
import cv2

In [None]:
def create_plot(augmentor, task, conf_scores=None):
    """Display a task's support set and its generatively augmented version.

    Args:
        augmentor: Object exposing ``augment(support_set, conf_scores=...)``,
            e.g. a ``cdmetadl.augmentation.GenerativeAugmentation``.
        task: Task whose ``support_set`` is augmented and plotted.
        conf_scores: Confidence scores forwarded to ``augmentor.augment``.
            Defaults to ``[0.1] * 5``, matching the 5-way sampler used in this
            notebook. Pass a list of matching length when sampling a different
            number of ways (presumably one score per class — TODO confirm
            against ``GenerativeAugmentation.augment``).

    Returns:
        None. Two figures are shown via ``fig.show()``: the original support
        set first, then the augmented one.
    """
    if conf_scores is None:
        # Built per call so no list object is shared between calls.
        conf_scores = [0.1] * 5

    augmented_set = augmentor.augment(task.support_set, conf_scores=conf_scores)

    panels = (
        (task.support_set, 'Original data'),
        (augmented_set, 'Generative Augmented data'),
    )
    for images, title in panels:
        fig = cdmetadl.notebooks.helpers.show_images_grid_plotly(images)
        fig.update_layout(title=title)
        fig.show()

## Canny Edge Detection + ControlNet Canny

In [None]:
def edge_detection(self, image, low_threshold=100, high_threshold=200):
    """Return a 3-channel Canny edge map of ``image`` as a PIL image.

    Patched onto ``GenerativeAugmentation`` below, so ``self`` is the
    augmentor instance; it is not used here.

    Args:
        image: Anything ``np.array`` can convert (e.g. a PIL image). Assumed
            to be 8-bit, which ``cv2.Canny`` requires.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        ``PIL.Image`` built from an (H, W, 3) array in which the single-channel
        edge map is replicated into all three channels.
    """
    pixels = np.array(image)
    edges = cv2.Canny(pixels, low_threshold, high_threshold)  # single channel, (H, W)
    # Replicate to three channels; presumably the ControlNet Canny pipeline
    # expects RGB-shaped conditioning images.
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges_rgb)

# Route the class's edge_detection to the Canny version above; presumably the
# augmentor calls it to build the ControlNet conditioning image.
setattr(cdmetadl.augmentation.GenerativeAugmentation, "edge_detection", edge_detection)

generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1.0,
    diffusion_model_id="lllyasviel/sd-controlnet-canny",
    keep_original_data=False,
)

create_plot(generative_augmentor, task)

## Semantic Segmentation (UniFormer) + ControlNet Seg

In [None]:
from cdmetadl.annotator.uniformer import UniformerDetector

# Segmentation annotator, built once for this cell and called per image by edge_detection below
# (construction presumably loads model weights, hence it is not done per image).
apply_uniformer = UniformerDetector()

def edge_detection(self, image):
    """Return the UniFormer segmentation map of ``image`` as a PIL image.

    Patched onto ``GenerativeAugmentation`` below; ``self`` is unused. The
    detector runs with autograd disabled.
    """
    pixels = np.array(image)
    with torch.no_grad():
        seg_map = apply_uniformer(pixels)
    return Image.fromarray(seg_map)


# Route the class's edge_detection to the UniFormer version above; presumably
# the augmentor calls it to build the ControlNet conditioning image.
setattr(cdmetadl.augmentation.GenerativeAugmentation, "edge_detection", edge_detection)

generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1.0,
    diffusion_model_id="lllyasviel/sd-controlnet-seg",
    keep_original_data=False,
)

create_plot(generative_augmentor, task)

## HED Boundaries + ControlNet HED

In [None]:
from cdmetadl.annotator.hed import HEDdetector

# HED boundary detector, built once for this cell and called per image by edge_detection below
# (construction presumably loads model weights, hence it is not done per image).
apply_hed = HEDdetector()

def edge_detection(self, image):
    """Return the HED boundary map of ``image`` as a PIL image.

    Patched onto ``GenerativeAugmentation`` below; ``self`` is unused. The
    detector runs with autograd disabled.
    """
    pixels = np.array(image)
    with torch.no_grad():
        boundary_map = apply_hed(pixels)
    return Image.fromarray(boundary_map)


# Route the class's edge_detection to the HED version above; presumably the
# augmentor calls it to build the ControlNet conditioning image.
setattr(cdmetadl.augmentation.GenerativeAugmentation, "edge_detection", edge_detection)

generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1.0,
    diffusion_model_id="lllyasviel/sd-controlnet-hed",
    keep_original_data=False,
)

create_plot(generative_augmentor, task)

## M-LSD Lines + ControlNet M-LSD

In [None]:
from cdmetadl.annotator.mlsd import MLSDdetector

# M-LSD line detector, built once for this cell and called per image by edge_detection below
# (construction presumably loads model weights, hence it is not done per image).
apply_mlsd = MLSDdetector()

def edge_detection(self, image, value_threshold=0.1, distance_threshold=0.1):
    """Return the M-LSD straight-line map of ``image`` as a PIL image.

    Patched onto ``GenerativeAugmentation`` below; ``self`` is unused. The
    detector runs with autograd disabled.

    Args:
        image: Anything ``np.array`` can convert (e.g. a PIL image).
        value_threshold: First threshold forwarded to ``apply_mlsd``
            (presumably the line-score threshold, as in the upstream
            ControlNet annotator — TODO confirm).
        distance_threshold: Second threshold forwarded to ``apply_mlsd``
            (presumably the minimum line-length/distance threshold — TODO
            confirm).

    Returns:
        ``PIL.Image`` built from the detector's output array.
    """
    pixels = np.array(image)
    with torch.no_grad():
        line_map = apply_mlsd(pixels, value_threshold, distance_threshold)
    return Image.fromarray(line_map)


# Route the class's edge_detection to the M-LSD version above; presumably the
# augmentor calls it to build the ControlNet conditioning image.
setattr(cdmetadl.augmentation.GenerativeAugmentation, "edge_detection", edge_detection)

generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1.0,
    diffusion_model_id="lllyasviel/sd-controlnet-mlsd",
    keep_original_data=False,
)

create_plot(generative_augmentor, task)

## Midas Depth Maps + ControlNet Depth

In [None]:
from cdmetadl.annotator.midas import MidasDetector

# MiDaS detector, built once for this cell and called per image by edge_detection below
# (construction presumably loads model weights, hence it is not done per image).
apply_midas = MidasDetector()

def edge_detection(self, image):
    """Return the MiDaS depth map of ``image`` as a PIL image.

    Patched onto ``GenerativeAugmentation`` below and paired with the
    ``sd-controlnet-depth`` model, so it must return the *depth* map.
    ``apply_midas`` yields a pair of maps. Presumably that pair is
    ``(depth, normal)``, as in the upstream ControlNet ``MidasDetector`` —
    TODO confirm against ``cdmetadl.annotator.midas``. The depth map is
    element 0; element 1, the normal map, belongs to the ControlNet-normal
    section.
    """
    pixels = np.array(image)
    with torch.no_grad():
        detected_maps = apply_midas(pixels)
    depth_map = detected_maps[0]
    return Image.fromarray(depth_map)


# Route the class's edge_detection to the MiDaS version above; presumably the
# augmentor calls it to build the ControlNet conditioning image.
setattr(cdmetadl.augmentation.GenerativeAugmentation, "edge_detection", edge_detection)

generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1.0,
    diffusion_model_id="lllyasviel/sd-controlnet-depth",
    keep_original_data=False,
)
create_plot(generative_augmentor, task)

## Midas Normal Maps + ControlNet Normal

In [None]:
from cdmetadl.annotator.midas import MidasDetector

# The same MidasDetector serves both the depth and the normal sections. Reuse the
# instance from the depth cell when it is already loaded. Constructing a new one here
# would hold two copies alive until the rebind (construction presumably loads model weights).
# A fresh detector is built only when this cell is run on its own.
if "apply_midas" not in globals():
    apply_midas = MidasDetector()

def edge_detection(self, image):
    """Return the MiDaS normal map of ``image`` as a PIL image.

    Patched onto ``GenerativeAugmentation`` below and paired with the
    ``sd-controlnet-normal`` model. ``apply_midas`` yields a pair of maps,
    and element 1 is taken. That is presumably the normal map, as in the
    upstream ControlNet ``MidasDetector``, whose element 0 is depth — TODO
    confirm.
    """
    pixels = np.array(image)
    with torch.no_grad():
        detected_maps = apply_midas(pixels)

    normal_map = detected_maps[1]
    return Image.fromarray(normal_map)


# Route the class's edge_detection to the MiDaS normal-map version above;
# presumably the augmentor calls it to build the ControlNet conditioning image.
setattr(cdmetadl.augmentation.GenerativeAugmentation, "edge_detection", edge_detection)

generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1.0,
    diffusion_model_id="lllyasviel/sd-controlnet-normal",
    keep_original_data=False,
)

create_plot(generative_augmentor, task)