In [4]:
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
import pandas as pd
from tqdm.auto import tqdm
from skimage import io

from FF.image_processing import thresh, prepare_scalemap, edgedetector, measure_D
from FF.fractal_generation import midpoint_displacement, mountainpro, branching_network


import warnings
warnings.filterwarnings("ignore", category=UserWarning)


In [5]:
def generate_coastline(D, iterations, P = 1):
    """Build a coastline fractal from a midpoint-displacement height map.

    :param D: Fractal dimension forwarded to ``midpoint_displacement``.
    :param iterations: Iteration count forwarded to every processing stage.
    :param P: Randomness parameter forwarded to ``midpoint_displacement``.
    :return: ``(threshmap, coastline)`` -- the thresholded map and the edge map
             that ``edgedetector`` extracts from it.
    """
    scalemap = prepare_scalemap(midpoint_displacement(iterations, P=P, D=D))
    # thresh returns a second value (named bwratio originally) that is not needed here.
    threshmap, _ = thresh(scalemap, iterations)
    return threshmap, edgedetector(threshmap, iterations)

def generate_mountain(D, iterations, P=1, zslice=0.5):
    """Build a mountain-profile fractal from a midpoint-displacement height map.

    :param D: Fractal dimension forwarded to ``midpoint_displacement``.
    :param iterations: Iteration count forwarded to ``midpoint_displacement`` and ``mountainpro``.
    :param P: Randomness parameter forwarded to ``midpoint_displacement``.
    :param zslice: Slice level forwarded to ``mountainpro``.
    :return: ``(image, mountain)`` -- the inverted slice and the mountain profile,
             both multiplied by 255.
    """
    meshmap = midpoint_displacement(iterations, P=P, D=D)
    scalemap = prepare_scalemap(meshmap)
    # Named z_slice_map (not `slice`) so the built-in `slice` is not shadowed.
    z_slice_map, mountain = mountainpro(scalemap, iterations, zslice)
    # presumably z_slice_map is a 0/1 mask, so 1 - mask inverts it before scaling -- TODO confirm
    return (1 - z_slice_map) * 255, mountain * 255

def generate_branch_network(network_params, neuron_params):
    """Generate a branching network and return its binary masks.

    Both parameter dicts are passed straight to ``branching_network.generate_network``.

    :return: ``(filled, outline)`` -- the ``'filled'`` and ``'outline'`` entries of
             the mask dict produced by ``generate_binary_mask``.
    """
    masks = branching_network.generate_network(
        network_params=network_params,
        neuron_params=neuron_params,
    ).generate_binary_mask()
    return masks['filled'], masks['outline']

In [7]:
def batch_generate_fractals(num_fractals, iterations, D_range, P,
                            fractal_type='coastline',
                            output_dir='fractals_batch',
                            neuron_params=None, network_params=None,
                            zslice=0.5):
    """
    Generates multiple fractal images with varying D values and saves them in the specified directory.

    Each fractal is written as ``{fractal_type}_fractal_{D_measured}.tif`` (uint8), where
    ``D_measured`` is the value reported by ``measure_D``. All rows are collected in
    ``{fractal_type}_labels.csv`` (columns: filename, d_value, fractal_type).

    :param num_fractals: Number of fractals to generate.
    :param iterations: Number of iterations for the fractal generation algorithms (unused for 'branch network').
    :param D_range: Tuple (min_D, max_D); target D values are spaced evenly over it (unused for 'branch network').
    :param P: The parameter controlling the randomness in the midpoint displacement (unused for 'branch network').
    :param fractal_type: The type of fractal to generate ('coastline', 'mountain', 'branch network').
    :param output_dir: Directory where the generated fractal images will be saved.
    :param neuron_params: Dictionary of parameters for neuron generation (required for 'branch network' type).
    :param network_params: Dictionary of parameters for network generation (required for 'branch network' type).
    :param zslice: The z-slice value for mountain generation (used for 'mountain' type).
    """
    # Validate the configuration before touching the filesystem or doing expensive work.
    if fractal_type not in ('coastline', 'mountain', 'branch network'):
        print('Invalid fractal type chosen, aborting.')
        return
    if fractal_type == 'branch network' and (neuron_params is None or network_params is None):
        print('neuron_params and network_params are required for branch network fractals, aborting.')
        return

    os.makedirs(output_dir, exist_ok=True)

    min_D, max_D = D_range
    D_values = np.linspace(min_D, max_D, num_fractals)

    data = []

    for D in tqdm(D_values, desc="Generating Fractals", unit="fractal"):
        # Only the second array returned by each generator is measured and saved.
        if fractal_type == 'coastline':
            _, fractal = generate_coastline(D, iterations, P)
        elif fractal_type == 'mountain':
            _, fractal = generate_mountain(D, iterations, P, zslice)
        else:  # 'branch network' -- D is not an input to this generator
            _, outline = generate_branch_network(network_params, neuron_params)
            # Unpack first, then scale. Multiplying the returned tuple by 255 would
            # repeat the tuple 255 times, and the two-name unpack would then fail.
            fractal = outline * 255

        D_measured = measure_D(fractal, min_size=8, max_size=np.shape(fractal)[0] // 5, n_sizes=20, invert=False)

        # NOTE(review): two samples with the same measured D overwrite each other's file.
        tiff_file = os.path.join(output_dir, f"{fractal_type}_fractal_{D_measured}.tif")
        imageio.imwrite(tiff_file, fractal.astype(np.uint8))  # Ensure the image is in uint8 format

        data.append([os.path.basename(tiff_file), D_measured, fractal_type])

    # Save the filenames and D values to a CSV file
    csv_file = os.path.join(output_dir, f"{fractal_type}_labels.csv")
    df = pd.DataFrame(data, columns=['filename', 'd_value', 'fractal_type'])
    df.to_csv(csv_file, index=False)

    print(f'Batch generation complete. {num_fractals} {fractal_type} fractals saved to {output_dir}.')
    print(f'Labels saved to {csv_file}.')

# Neuron-level settings for the 'branch network' generator.
neuron_params = dict(
    depth=3,
    mean_soma_radius=60,
    std_soma_radius=15,
    D=1.2,
    branch_angle=np.pi / 4,
    mean_branches=1.5,
    weave_type='Gauss',
    randomness=0.2,
    curviness='Gauss',
    curviness_magnitude=1.5,
    n_primary_dendrites=5,
)

# Canvas size and neuron count for the whole network.
network_params = dict(width=2048, height=2048, num_neurons=10, network_id='test')

# `iterations`, `D_range` and `P` are required by the signature, but the
# 'branch network' path of batch_generate_fractals never reads them.
# NOTE(review): absolute, machine-specific output path.
batch_generate_fractals(
    num_fractals=500,
    iterations=8,
    D_range=(1.1, 1.7),
    P=0.5,
    fractal_type='branch network',
    output_dir='/home/apd/Projects/FractalFluency/datasets/data_dump',
    neuron_params=neuron_params,
    network_params=network_params,
)


In [15]:
def _clamp(value, low, high):
    """Limit value to the closed interval [low, high] (keeps int inputs as int)."""
    return min(max(value, low), high)


def _save_forced_sample(output_dir, fractal_type, image, fractal, D_measured, save_fractal):
    """Write an accepted sample as uint8 TIFF(s) and return the image file's basename.

    The image goes to ``{fractal_type}_image_{D:.5f}.tif``. When ``save_fractal`` is
    true, the measured fractal also goes to ``{fractal_type}_fractal_{D:.5f}.tif``.
    NOTE(review): samples whose measured D agree to 5 decimals overwrite each other.
    """
    tiff_file = os.path.join(output_dir, f"{fractal_type}_image_{D_measured:.5f}.tif")
    io.imsave(tiff_file, image.astype(np.uint8))
    if save_fractal:
        io.imsave(os.path.join(output_dir, f"{fractal_type}_fractal_{D_measured:.5f}.tif"), fractal.astype(np.uint8))
    return os.path.basename(tiff_file)


def batch_generate_fractals_multi_forced(num_fractals_per_type,
                                         iterations,
                                         D_ranges,
                                         P,
                                         fractal_types=('coastline', 'mountain', 'branch network'),
                                         output_dir='/home/apd/Projects/FractalFluency/datasets/data_dump',
                                         neuron_params=None,
                                         network_params=None,
                                         zslice=0.5,
                                         tolerance=0.05,
                                         max_attempts=100,
                                         save_fractal=False):
    """
    Generate fractals whose measured D (from ``measure_D``) lands within ``tolerance``
    of evenly spaced targets, retrying and nudging the generation parameters on misses.

    For each type in ``fractal_types``, ``num_fractals_per_type`` targets are spread
    evenly over ``D_ranges[type]``. Each target gets at most ``max_attempts`` misses;
    every few misses the input D (and, for branch networks, the neuron parameters) is
    stepped toward the target and clamped to a fixed range. Targets that are never hit
    are counted as skipped. Accepted samples are written to ``output_dir``, and their
    labels are written to ``output_dir/labels.csv`` (columns: filename, d_value, fractal_type).

    ``neuron_params`` / ``network_params`` are copied for every attempt. The caller's
    dicts are never modified, so every target starts from the same baseline.

    :raises ValueError: if 'branch network' is requested without neuron_params/network_params.
    :raises AssertionError: if a requested type has no entry in ``D_ranges``.
    """
    fractal_types = list(fractal_types)
    if 'branch network' in fractal_types and (neuron_params is None or network_params is None):
        raise ValueError("neuron_params and network_params must be provided for 'branch network' fractals")

    os.makedirs(output_dir, exist_ok=True)
    data = []
    skipped = 0
    csv_file = os.path.join(output_dir, 'labels.csv')

    def write_labels():
        pd.DataFrame(data, columns=['filename', 'd_value', 'fractal_type']).to_csv(csv_file, index=False)

    for fractal_type in fractal_types:
        D_range = D_ranges.get(fractal_type, None)
        assert D_range is not None, f"D_range must be provided for fractal type '{fractal_type}', should be tuple (D_min, D_max), e.g. (1.2, 1.6)"

        min_D, max_D = D_range
        D_values = np.linspace(min_D, max_D, num_fractals_per_type)
        desc = f"Generating {fractal_type} fractals"

        if fractal_type == 'coastline':
            for target_D in tqdm(D_values, desc=desc, unit="fractal"):
                adjusted_D = target_D
                # for/else: the else branch runs only when every attempt missed.
                for attempts in range(1, max_attempts + 1):
                    image, fractal = generate_coastline(adjusted_D, iterations, P)
                    D_measured = measure_D(fractal, min_size=8, max_size=np.shape(fractal)[0] // 5, n_sizes=20, invert=False)
                    if abs(D_measured - target_D) <= tolerance:
                        filename = _save_forced_sample(output_dir, fractal_type, image, fractal, D_measured, save_fractal)
                        data.append([filename, D_measured, fractal_type])
                        break
                    if attempts % 5 == 0:
                        adjusted_D = _clamp(adjusted_D + (0.01 if D_measured < target_D else -0.01), 1.0, 2.0)
                else:
                    skipped += 1

        elif fractal_type == 'mountain':
            for target_D in tqdm(D_values, desc=desc, unit="fractal"):
                adjusted_D = target_D
                for attempts in range(1, max_attempts + 1):
                    image, fractal = generate_mountain(adjusted_D, iterations, P, zslice)
                    D_measured = measure_D(fractal, min_size=16, max_size=np.shape(fractal)[0] // 6, n_sizes=100, invert=False)
                    if abs(D_measured - target_D) <= tolerance:
                        filename = _save_forced_sample(output_dir, fractal_type, image, fractal, D_measured, save_fractal)
                        data.append([filename, D_measured, fractal_type])
                        break
                    # Asymmetric step: moves up twice as fast as it moves down.
                    if attempts % 2 == 0:
                        adjusted_D = _clamp(adjusted_D + (0.02 if D_measured < target_D else -0.01), 1.0, 2.0)
                else:
                    print(f"skipped mountain fractal with target D: {target_D}" )
                    print(f"final adjusted D: {adjusted_D}")
                    skipped += 1

        elif fractal_type == 'branch network':
            # Baselines come from the caller's (unmodified) dicts for every target.
            base_mean_branches = neuron_params.get('mean_branches', 1.5)
            base_n_primary_dendrites = neuron_params.get('n_primary_dendrites', 3)
            base_branch_angle = neuron_params.get('branch_angle', np.pi / 4)
            base_total_length = neuron_params.get('total_length', 400)
            num_neurons = network_params.get('num_neurons', 10)

            for target_D in tqdm(D_values, desc=desc, unit="fractal"):
                adjusted_D = target_D
                adjusted_mean_branches = base_mean_branches
                adjusted_n_primary_dendrites = base_n_primary_dendrites
                adjusted_branch_angle = base_branch_angle
                adjusted_total_length = base_total_length

                for attempts in range(1, max_attempts + 1):
                    # Fresh copies per attempt, so the caller's dicts are never mutated.
                    attempt_neuron_params = {
                        **neuron_params,
                        'D': adjusted_D,
                        'mean_branches': adjusted_mean_branches,
                        'n_primary_dendrites': adjusted_n_primary_dendrites,
                        'branch_angle': adjusted_branch_angle,
                        'total_length': adjusted_total_length,
                    }
                    attempt_network_params = {**network_params, 'num_neurons': num_neurons}

                    image, fractal = generate_branch_network(attempt_network_params, attempt_neuron_params)
                    fractal = (fractal * 255).astype(np.uint8)
                    image = (image * 255).astype(np.uint8)
                    D_measured = measure_D(fractal, min_size=8, max_size=np.shape(fractal)[0] // 5, n_sizes=20, invert=False)

                    if abs(D_measured - target_D) <= tolerance:
                        filename = _save_forced_sample(output_dir, fractal_type, image, fractal, D_measured, save_fractal)
                        data.append([filename, D_measured, fractal_type])
                        break

                    if attempts % 10 == 0:
                        # +1 pushes every knob toward more complexity, -1 toward less.
                        step = 1 if D_measured < target_D else -1
                        adjusted_D = _clamp(adjusted_D + 0.05 * step, 1.0, 2.0)
                        adjusted_branch_angle = _clamp(adjusted_branch_angle + step * 5 * np.pi / 180, np.pi / 8, np.pi / 2)
                        adjusted_total_length = _clamp(adjusted_total_length + 50 * step, 200, 800)
                        # NOTE(review): lower bounds below are assumed (negative counts make no sense) -- confirm
                        adjusted_mean_branches = max(0.0, adjusted_mean_branches + 0.25 * step)
                        adjusted_n_primary_dendrites = max(1, adjusted_n_primary_dendrites + step)
                else:
                    skipped += 1

        else:
            print(f"Unknown fractal type '{fractal_type}', skipping.")
            continue

        # Checkpoint after each type so labels for finished types survive a later crash.
        write_labels()

    # Final write; this also creates the (header-only) file when no type was processed.
    write_labels()

    total_fractals = len(data)
    print(f'\nBatch generation complete, with {skipped} skipped. \n{total_fractals} fractals saved to {output_dir}.')
    print(f'Labels saved to {csv_file}.')


In [16]:
# --- Run configuration for batch_generate_fractals_multi_forced -----------------
num_fractals_per_type = 10   # fractals produced for each requested type
iterations = 9               # iteration count passed to the generators
P = 1                        # randomness parameter for midpoint displacement
zslice = 0.5                 # slice level used by the mountain generator
tolerance = 0.05             # largest accepted |measured D - target D|
max_attempts = 100           # misses allowed per target before it counts as skipped
# NOTE(review): absolute, machine-specific output path.
output_dir = '/home/apd/Projects/FractalFluency/datasets/new_test'

# Only branch networks are generated in this run; the other ranges stay for reference.
fractal_types = ['branch network']

# Target (min_D, max_D) per fractal type.
D_ranges = {
    'coastline': (1.2, 1.8),
    'mountain': (1.0, 1.6),
    'branch network': (1.2, 1.8),
}

# Baseline neuron settings. 'D' is a placeholder: the forced generator sets it on every attempt.
neuron_params = dict(
    depth=5,
    mean_soma_radius=0,
    std_soma_radius=0,
    D=None,
    branch_angle=np.pi / 4,
    mean_branches=1.5,
    weave_type='Gauss',
    randomness=0.2,
    curviness='Gauss',
    curviness_magnitude=1.5,
    n_primary_dendrites=5,
    total_length=400,
    initial_thickness=20,
)

# Network canvas and population.
network_params = dict(num_neurons=10, width=2048, height=2048, edge_margin=200)

os.makedirs(output_dir, exist_ok=True)

batch_generate_fractals_multi_forced(
    num_fractals_per_type=num_fractals_per_type,
    iterations=iterations,
    D_ranges=D_ranges,
    P=P,
    fractal_types=fractal_types,
    output_dir=output_dir,
    neuron_params=neuron_params,
    network_params=network_params,
    zslice=zslice,
    tolerance=tolerance,
    max_attempts=max_attempts,
    save_fractal=True,
)


Generating branch network fractals:   0%|          | 0/10 [00:00<?, ?fractal/s]


Batch generation complete, with 0 skipped. 
10 fractals saved to /home/apd/Projects/FractalFluency/datasets/new_test.
Labels saved to /home/apd/Projects/FractalFluency/datasets/new_test/labels.csv.
