In [None]:
import os
from datetime import datetime
import logging
import json
import gc
import random

import pandas as pd
import tifffile as tiff
import numpy as np
from tqdm.notebook import tqdm, tqdm_notebook

from flame import FLAMEImage
from flame.error import FLAMEImageError

In [None]:
# Raw inputs and processed outputs. Set FLAME_INPUT_DIREC / FLAME_OUTPUT_DIREC
# in the environment to run elsewhere without editing the notebook.
INPUT_DIREC = os.environ.get("FLAME_INPUT_DIREC", "/mnt/d/data/raw")
OUTPUT_DIREC = os.environ.get("FLAME_OUTPUT_DIREC", "/mnt/d/data/processed")
# Holds raw_image_index.csv; the dataset JSON and preview PNG are also written here.
DATASET_DIREC = os.path.join(os.getcwd(), "datasets")
DS_TYPE = "denoising"
# Frame counts for the low-frame (input) and high-frame (output) images.
INPUT_N_FRAMES = 5
OUTPUT_N_FRAMES = 40
assert 0 < INPUT_N_FRAMES < OUTPUT_N_FRAMES, "INPUT_N_FRAMES must be positive and below OUTPUT_N_FRAMES"

In [None]:
# Send log records (DEBUG and above) to a timestamped file in the working directory.
# Note: basicConfig does nothing if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel).
_log_filename = f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_logger.log"
logging.basicConfig(
    level=logging.DEBUG,
    filename=_log_filename,
    encoding="utf-8",
)
logger = logging.getLogger("main")

### Find images to be used in the dataset

In [None]:
# The index CSV lists the raw images; its 'id' and 'image' columns are used below.
IMAGE_INDEX_PATH = os.path.join(DATASET_DIREC, "raw_image_index.csv")
_index_exists = os.path.isfile(IMAGE_INDEX_PATH)
assert _index_exists, f"Image index not found at {IMAGE_INDEX_PATH}"
IMAGE_INDEX = pd.read_csv(filepath_or_buffer=IMAGE_INDEX_PATH)
IMAGE_INDEX.head()

In [None]:
# Relative image paths from the index (not referenced by the cells shown here).
PREINDEXED_IMAGES = IMAGE_INDEX.loc[:, "image"]

In [None]:
# Tile-metadata file name passed to FLAMEImage together with each image path.
TILE_DATA_FILENAME = "tileData.txt"

# Keep every indexed image that exists on disk, initializes as a FLAMEImage,
# and has at least OUTPUT_N_FRAMES frames per tile; keyed by its index id.
this_ds = {}
for idx, relpath in tqdm(
        zip(IMAGE_INDEX['id'], IMAGE_INDEX['image']),
        ascii=True,
        unit="image",
        total=len(IMAGE_INDEX)
    ):
    this_impath = os.path.join(INPUT_DIREC, relpath)
    if not os.path.isfile(this_impath):
        logger.error(f"Could not find image of id {idx} at {this_impath}")
        continue
    logger.info(f"Found image of id {idx} at {this_impath}")

    try:
        this_image = FLAMEImage(this_impath, TILE_DATA_FILENAME)
    except FLAMEImageError:
        # Skip images that fail to initialize, but keep the cause (traceback) in the log.
        logger.error(f"Could not initialize image of id {idx} at {this_impath}", exc_info=True)
        continue

    n_frames_per_tile = this_image.tileData.framesPerTile
    if n_frames_per_tile < OUTPUT_N_FRAMES:
        # Too few frames to build the OUTPUT_N_FRAMES target image.
        logger.warning(
            f"Skipping image of id {idx} at {this_impath} due to insufficient "
            f"framesPerTile ({n_frames_per_tile} not {OUTPUT_N_FRAMES})"
        )
        continue
    this_ds[idx] = this_image

In [None]:
# Report how many usable images were found; stop here if there are none.
_n_found = len(this_ds)
logger.info(f"Found {_n_found} images for dataset")
print(f"Found {_n_found} images for dataset")
if _n_found < 1:
    _no_images_msg = f"No valid images were found in {INPUT_DIREC}"
    logger.error(_no_images_msg)
    # RuntimeError (a subclass of Exception) instead of a bare Exception.
    raise RuntimeError(_no_images_msg)

### Documenting Dataset

In [None]:
# Dataset name: <YYYYMMDD>_<n images>I_<type>_<in>to<out>F
DS_NAME = "_".join([
    datetime.now().strftime('%Y%m%d'),
    f"{len(this_ds)}I",
    DS_TYPE,
    f"{INPUT_N_FRAMES}to{OUTPUT_N_FRAMES}F",
])

In [None]:
# Held-out fraction and RNG seed for the train/test split. The seed makes the
# split reproducible across runs; change it to draw a different split.
TEST_FRACTION = 0.1
TEST_SPLIT_SEED = 0


def _empty_pixel_stats(n_frames):
    """Return the stats record for one side (input or output) of the dataset.

    Pixel statistics start as None and are filled in after the images are processed.
    """
    return {
        "n_frames": n_frames,
        "pixel_mean": None,
        "pixel_min": None,
        "pixel_max": None,
        "pixel_p1pct": None,
        "pixel_1pct": None,
        "pixel_5pct": None,
        "pixel_95pct": None,
        "pixel_99pct": None,
        "pixel_99p9pct": None,
        "pixel_std": None,
    }


_image_ids = list(this_ds.keys())
dataset_json = {
    "FLAME_Dataset": {
        "name": DS_NAME,
        "type": DS_TYPE,
        "image_shapes": [],
        "input": _empty_pixel_stats(INPUT_N_FRAMES),
        "output": _empty_pixel_stats(OUTPUT_N_FRAMES),
        "image_ids": _image_ids,
        # random.sample draws without replacement, so every test id is distinct
        # (random.choices could repeat ids and shrink the real test subset).
        # A private Random instance leaves the global random state untouched.
        "test_ids": random.Random(TEST_SPLIT_SEED).sample(
            _image_ids, k=int(TEST_FRACTION * len(_image_ids))
        ),
    }
}

In [None]:
# Echo the dataset and test-subset sizes to both the notebook and the log file.
_ds_info = dataset_json['FLAME_Dataset']
for _summary_line in (
    f"There are {len(_ds_info['image_ids'])} images in the dataset",
    f"There is a {len(_ds_info['test_ids'])}-image testing subset",
):
    print(_summary_line)
    logger.info(_summary_line)

### Creating Dataset

In [None]:
# Everything for this dataset lives under OUTPUT_DIREC/DS_NAME; exist_ok keeps re-runs safe.
DS_OUTPUT_DIREC = os.path.join(OUTPUT_DIREC, DS_NAME)
os.makedirs(name=DS_OUTPUT_DIREC, exist_ok=True)
logger.info(f"Created dataset output directory at {DS_OUTPUT_DIREC}")

In [None]:
# One sub-directory per split; images are routed by whether their id is in test_ids.
TRAIN_DIREC = os.path.join(DS_OUTPUT_DIREC, "train")
TEST_DIREC = os.path.join(DS_OUTPUT_DIREC, "test")
_split_dirs = (("training", TRAIN_DIREC), ("testing", TEST_DIREC))
for _, _split_direc in _split_dirs:
    os.makedirs(_split_direc, exist_ok=True)
for _split_label, _split_direc in _split_dirs:
    logger.info(f"Created {_split_label} data directory at {_split_direc}")

##### Assert that all dtypes are the same

This check matters for computing statistics over the dataset. When every image shares the same dtype, pixel-level intensity statistics can be computed across the whole dataset at once; this is the only case currently supported.

If the dataset contains several different dtypes, a separate sub-dataset has to be created for each dtype, and each sub-dataset gets its own pixel-level intensity statistics for normalization. Once each sub-dataset has been normalized with its own pixel intensity statistics, the sub-datasets can be recombined into a larger dataset. *THIS IS NOT CURRENTLY SUPPORTED*

In [None]:
# Pixel statistics are pooled across all images, which is only supported when
# every image shares a single dtype.
FLAME_Images = list(this_ds.values())
adtype = FLAME_Images[0].imDType
_mismatched_ids = [im_id for im_id, fl_im in this_ds.items() if fl_im.imDType != adtype]
assert not _mismatched_ids, (
    f"Images {_mismatched_ids} have a dtype different from {adtype}; "
    "mixed-dtype datasets are not supported"
)

##### Create and save input and output images for the low and high frame counts

In [None]:
def _save_frames(flame_im, idx, n_frames, split_direc):
    """Write ``flame_im.get_frames(0, n_frames)`` to ``split_direc/id{idx}_frames{n_frames}.tif``.

    Presumably get_frames returns the sum of the first n_frames frames (TODO confirm
    against FLAMEImage.get_frames). Returns ``(path, frames)``.
    """
    frames_path = os.path.join(split_direc, f"id{idx}_frames{n_frames}.tif")
    frames = flame_im.get_frames(0, n_frames)
    logger.info(f"Saving {frames_path}...")
    tiff.imwrite(frames_path, frames)
    return frames_path, frames


input_frames_paths = []
output_frames_paths = []
TEST_IDS = dataset_json['FLAME_Dataset']['test_ids']
_test_id_set = set(TEST_IDS)  # constant-time membership check per image
# Per-image (channels, pixels) arrays, concatenated once after the loop. Calling
# np.concat inside the loop re-copied the whole accumulated array every time.
_input_chunks = []
_output_chunks = []
for idx, flame_im in tqdm(this_ds.items(), ascii=True, unit="image", total=len(this_ds)):
    split_direc = TEST_DIREC if idx in _test_id_set else TRAIN_DIREC
    flame_im.openImage()  # open the image before reading frames from it
    try:
        input_frames_path, input_frames = _save_frames(flame_im, idx, INPUT_N_FRAMES, split_direc)

        # Record each distinct input shape in the dataset JSON.
        image_shapes = dataset_json['FLAME_Dataset']['image_shapes']
        if input_frames.shape not in image_shapes:
            image_shapes.append(input_frames.shape)

        output_frames_path, output_frames = _save_frames(flame_im, idx, OUTPUT_N_FRAMES, split_direc)

        # Append both paths together so the two lists stay index-aligned.
        input_frames_paths.append(input_frames_path)
        output_frames_paths.append(output_frames_path)

        # One row per channel. copy() so the stored pixels never alias memory that
        # flame_im may own after closeImage() (whether get_frames returns a view is
        # not visible here).
        nchannels = len(flame_im.tileData.channelsSaved)
        _input_chunks.append(input_frames.reshape(nchannels, -1).copy())
        _output_chunks.append(output_frames.reshape(nchannels, -1).copy())
    finally:
        flame_im.closeImage()  # release this image even if saving failed

# np.concatenate also works on NumPy < 2.0, where the np.concat alias does not exist.
all_input_pixels = np.concatenate(_input_chunks, axis=-1) if _input_chunks else None
all_output_pixels = np.concatenate(_output_chunks, axis=-1) if _output_chunks else None
del _input_chunks, _output_chunks

In [None]:
# dataset_json key -> percentile (in percent) for the per-channel statistics.
_PERCENTILE_KEYS = {
    'pixel_p1pct': 0.1,
    'pixel_1pct': 1,
    'pixel_5pct': 5,
    'pixel_95pct': 95,
    'pixel_99pct': 99,
    'pixel_99p9pct': 99.9,
}

for name, arr in zip(["input", "output"], [all_input_pixels, all_output_pixels]):
    stats = dataset_json['FLAME_Dataset'][name]
    # arr has one row per channel, so percentiles are per channel. A single call
    # with every q copies/partitions the large array once instead of six times.
    per_q = np.percentile(arr, list(_PERCENTILE_KEYS.values()), axis=1)  # (n_q, n_channels)
    for key, channel_values in zip(_PERCENTILE_KEYS, per_q):
        stats[key] = channel_values.tolist()
    # min/max/mean/std are pooled over all channels (single numbers). .item() keeps
    # the array's own scalar type, so float data is not truncated to int.
    stats['pixel_min'] = np.min(arr).item()
    stats['pixel_max'] = np.max(arr).item()
    stats['pixel_mean'] = float(np.mean(arr))
    stats['pixel_std'] = float(np.std(arr))

# The pooled pixel arrays can be very large; free them now.
del all_input_pixels
del all_output_pixels
gc.collect()


In [None]:
# Write the dataset description next to the image index. The context manager
# closes (and flushes) the file right away instead of leaving it to the GC.
json_path = os.path.join(DATASET_DIREC, f"{DS_NAME}.json")
logger.info(f"Saving dataset config JSON to {json_path}")
with open(json_path, "w", encoding="utf-8") as json_file:
    json.dump(dataset_json, json_file)

### Verification of processing

In [None]:
import numpy as np
import tifffile as tiff
from matplotlib import pyplot as plt
import random

In [None]:
# Number of randomly chosen images shown in the verification figure.
N_PREVIEW_IMAGES = 5
# Distinct random positions into the (index-aligned) input/output path lists.
choices = random.sample(
    range(len(input_frames_paths)),
    k=min(N_PREVIEW_IMAGES, len(input_frames_paths)),
)

In [None]:
def _normalize_for_display(frames, stats):
    """Rescale one saved frame stack to [0, 1] for display.

    Axis 0 of ``frames`` is moved last (assumed to be the channel axis, matching
    the per-channel statistics). Values are then clipped to the dataset's
    per-channel ``pixel_1pct``..``pixel_99pct`` range taken from ``stats`` and
    mapped linearly to [0, 1]. Returns a float32 array.
    """
    image = frames.transpose(1, 2, 0).astype(np.float64)
    lower = np.array(stats['pixel_1pct'])
    upper = np.array(stats['pixel_99pct'])
    # A channel whose 1st and 99th percentiles coincide would compute 0/0 = NaN;
    # use a unit span there so that channel renders as 0 instead.
    span = np.where(upper > lower, upper - lower, 1.0)
    image = (np.clip(image, lower, upper) - lower) / span
    return image.astype(np.float32)


input_json, output_json = dataset_json['FLAME_Dataset']['input'], dataset_json['FLAME_Dataset']['output']
fig, axes = plt.subplots(len(choices), 2, figsize=(8, 4 * len(choices)), squeeze=False)
for row, choice in enumerate(choices):
    left_ax, right_ax = axes[row]

    # Low frame count on the left, high frame count on the right.
    left_ax.imshow(_normalize_for_display(tiff.imread(input_frames_paths[choice]), input_json))
    right_ax.imshow(_normalize_for_display(tiff.imread(output_frames_paths[choice]), output_json))

    # Label the row with the last two path components of the source image. The id
    # is parsed back out of the "id{idx}_frames{N}.tif" file name.
    imname = os.path.basename(input_frames_paths[choice])
    im_id = int(imname.split("_")[0].split("id")[1])
    im_path = IMAGE_INDEX.loc[IMAGE_INDEX['id'] == im_id, 'image'].iloc[0]
    left_ax.set_ylabel("\n".join(im_path.split(os.path.sep)[-2:]))

    if row == 0:
        left_ax.set_title(f"{INPUT_N_FRAMES} frames")
        right_ax.set_title(f"{OUTPUT_N_FRAMES} frames")

fig.tight_layout()
fig.savefig(os.path.join(DATASET_DIREC, f"{DS_NAME}.png"))
