In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline  

In [None]:
# Root folder of the public datasets, relative to the notebook's working directory.
path_to_dataset: str = "../../public_data"

In [None]:
import pathlib

import cdmetadl.helpers.general_helpers
import cdmetadl.dataset

# Look up the "BRD" (Birds) entry under the dataset root; check_datasets
# presumably also validates that it exists on disk (TODO confirm).
dataset_path = pathlib.Path(path_to_dataset)
dataset_info_dict = cdmetadl.helpers.general_helpers.check_datasets(
    dataset_path,
    ["BRD"],
)

# Wrap the Birds metadata in an ImageDataset used for task sampling below.
dataset = cdmetadl.dataset.ImageDataset(
    "Birds",
    dataset_info_dict["BRD"],
)

In [None]:
import cdmetadl.samplers

# Constant-value samplers: presumably every sampled task is 10-way (classes)
# with 4 support shots per class, plus a query set of size 4 (TODO confirm
# whether query_size is per class).
n_way_sampler, k_shot_sampler = (
    cdmetadl.samplers.ValueSampler(value=10),
    cdmetadl.samplers.ValueSampler(value=4),
)

task = dataset.generate_task(n_way_sampler, k_shot_sampler, query_size=4)

In [None]:
import cdmetadl.augmentation
import cdmetadl.notebooks.helpers
import numpy as np
import torch
import os
import plotly.io as pio
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from IPython.display import display

In [None]:
def _show_plotly_figure(fig, title):
    """Set ``title`` on a plotly figure and display it as a static PNG.

    The figure is rasterised in memory with ``pio.to_image``. The original code
    wrote a temporary PNG into the working directory and then deleted it. If
    writing or displaying failed, that file was left behind, and its file
    handle was still open when ``os.remove`` ran, which fails on Windows.

    Args:
        fig: plotly figure to render (its layout is updated in place).
        title: title shown above the figure.
    """
    import io

    fig.update_layout(title=title)
    png_bytes = pio.to_image(fig, format="png")
    display(Image.open(io.BytesIO(png_bytes)))


def create_plot(augmentor, task, conf_scores=None):
    """Display a task's support set next to its generatively augmented version.

    Args:
        augmentor: object exposing ``augment(support_set, conf_scores=...)``,
            e.g. ``cdmetadl.augmentation.GenerativeAugmentation``.
        task: task whose ``support_set`` is augmented and plotted.
        conf_scores: scores forwarded to ``augmentor.augment``. Defaults to
            ``[0.1, 0.1, 0.1, 0.1, 0.1]``, the value previously hard-coded.
    """
    if conf_scores is None:
        # NOTE(review): only 5 scores while the task in this notebook is
        # 10-way; confirm whether ``augment`` expects one score per class.
        conf_scores = [0.1, 0.1, 0.1, 0.1, 0.1]
    augmented_set_generative = augmentor.augment(task.support_set, conf_scores=conf_scores)

    _show_plotly_figure(
        cdmetadl.notebooks.helpers.show_images_grid_plotly(task.support_set),
        "Original data",
    )
    _show_plotly_figure(
        cdmetadl.notebooks.helpers.show_images_grid_plotly(augmented_set_generative),
        "Generative Augmented data",
    )

def generate_edge_map_plot(augmentor, task=None, conf_scores=None):
    """Plot each cached generation as original image / edge map / generated image.

    Runs ``augmentor.augment`` on ``task.support_set``, then reads
    ``augmentor.generated_images``. That list is presumably filled by
    ``augment`` when the augmentor was built with ``cache_images=True``
    (TODO confirm). Each entry is indexed with the keys ``'original_image'``,
    ``'edge_map'`` and ``'generated_image'``, and every value is passed to
    ``imshow``.

    Args:
        augmentor: augmentation object exposing ``augment`` and ``generated_images``.
        task: task whose support set is augmented. If omitted, falls back to
            the notebook-level ``task`` variable, which the original version
            always used implicitly.
        conf_scores: scores forwarded to ``augmentor.augment``. Defaults to
            ``[0.1, 0.1, 0.1, 0.1, 0.1]``, the value previously hard-coded.
    """
    if task is None:
        # Backward-compatible fallback to the notebook global. Prefer passing
        # ``task`` explicitly so the plot does not depend on hidden kernel state.
        task = globals()["task"]
    if conf_scores is None:
        conf_scores = [0.1, 0.1, 0.1, 0.1, 0.1]
    augmentor.augment(task.support_set, conf_scores=conf_scores)

    generated = augmentor.generated_images
    n_rows = len(generated)
    if n_rows == 0:
        # plt.subplots rejects nrows=0; report instead of crashing.
        print("No generated images to plot.")
        return

    panels = (
        ("original_image", "Original Image"),
        ("edge_map", "Edge Map"),
        ("generated_image", "Generated Image"),
    )
    # squeeze=False keeps ``axs`` 2-D even for a single row. Without it,
    # ``axs[i, j]`` raises IndexError when only one image was generated.
    fig, axs = plt.subplots(
        n_rows, len(panels), figsize=(7.5, 2.5 * n_rows), squeeze=False
    )
    for i, data in enumerate(generated):
        for j, (key, panel_title) in enumerate(panels):
            ax = axs[i, j]
            ax.imshow(data[key])
            ax.set_title(panel_title)
            ax.axis("off")
    plt.show()

In [None]:
# Segmentation-conditioned generative augmentation. Going by the parameter
# names (TODO confirm): original samples are not kept in the augmented set,
# and generated images are cached on the augmentor so the plot below can read
# them from ``generated_images``.
generative_augmentor = cdmetadl.augmentation.GenerativeAugmentation(
    threshold=0.75,
    scale=1,
    annotator_type="segmentation",
    keep_original_data=False,
    cache_images=True,
)

generate_edge_map_plot(generative_augmentor)

#<PIL.Image.Image image mode=RGB size=512x512 at 0x7F0FEDE0B810>
