### Collect relevant paths

In [705]:
import re
import glob
from aicsimageio import AICSImage
import numpy as np
import cv2
from scipy import ndimage
from PIL import Image
import os
import re
from natsort import natsorted
import matplotlib.pyplot as plt
from skimage import io, img_as_float32

In [706]:
path = directory_path
img_files = glob.glob(path + '*z*')
ch1_files = glob.glob(path + '*Nuc*')
ch2_files = glob.glob(path + '*Mito*')
ch3_files = glob.glob(path + '*Actin*')
ch4_files = glob.glob(path + '*Tub*')

In [707]:
def tile_image(image, patch_size=(256, 256), overlap=(0, 0)):
    if image is None:
        print("Error: Input image is None")
        return

    if len(image.shape) != 3:
        print("Error: Input image should be a 3-dimensional array (height x width x channels)")
        return
        
    height, width = image.shape[-2:]
    patch_height, patch_width = patch_size
    overlap_height, overlap_width = overlap

    patches = []

    # Calculate the number of patches in both dimensions
    num_patches_height = (height - overlap_height) // (patch_height - overlap_height) + 1
    num_patches_width = (width - overlap_width) // (patch_width - overlap_width) + 1

    # Iterate over each patch position
    for i in range(num_patches_height):
        for j in range(num_patches_width):
            # Calculate patch start and end positions
            
            
            start_h = i * (patch_height - overlap_height)
            end_h = min(start_h + patch_height, height)
            if(end_h - start_h < 256):
                start_h = (end_h - 256)
            
            
            start_w = j * (patch_width - overlap_width)
            end_w = min(start_w + patch_width, width)
            if(end_w - start_w < 256):
                start_w = (end_w - 256)
            

            # Extract the patch
            patch = image[:,start_h:end_h, start_w:end_w]

            # Append the patch to the list
            patches.append(patch)

    return patches

In [708]:
def natural_sort_key(s):
    """Key function for natural sorting."""
    return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', s)]

def min_max_normalize(image):
    min_val = np.min(image)
    max_val = np.max(image)
    normalized_image = (image - min_val) / (max_val - min_val)
    return normalized_image
    
def z_score_normalize(image):
    mean_val = np.mean(image)
    std_val = np.std(image)
    normalized_image = (image - mean_val) / std_val
    return normalized_image

def percentile_normalization(image, pmin=2, pmax=99.8, axis=None, dtype=np.uint16 ):
    '''
    Compute a percentile normalization for the given image.

    Parameters:
    - image (array): array of the image file.
    - pmin  (int or float): the minimal percentage for the percentiles to compute. 
                            Values must be between 0 and 100 inclusive.
    - pmax  (int or float): the maximal percentage for the percentiles to compute. 
                            Values must be between 0 and 100 inclusive.
    - axis : Axis or axes along which the percentiles are computed. 
             The default (=None) is to compute it along a flattened version of the array.
    - dtype (dtype): type of the wanted percentiles (uint16 by default)

    Returns:
    Normalized image (np.ndarray): An array containing the normalized image.
    '''

    if not (np.isscalar(pmin) and np.isscalar(pmax) and 0 <= pmin < pmax <= 100 ):
        raise ValueError("Invalid values for pmin and pmax")

    low_p  = np.percentile(image, pmin, axis=axis, keepdims=True)
    high_p = np.percentile(image, pmax, axis=axis, keepdims=True)

    if low_p == high_p:
        img_norm = image
        print(f"Same min {low_p} and high {high_p}, image may be empty")

    else:
        dtype_max = np.iinfo(dtype).max
        img_norm = dtype_max * (image - low_p) / (high_p - low_p)
        img_norm = img_norm.astype(dtype)

    return img_norm

In [709]:
img_files = sorted(img_files, key=natural_sort_key)
ch1_files = sorted(ch1_files, key=natural_sort_key)
ch2_files = sorted(ch2_files, key=natural_sort_key)
ch3_files = sorted(ch3_files, key=natural_sort_key)
ch4_files = sorted(ch4_files, key=natural_sort_key)

In [710]:
for file in img_files:
    print(file)

/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z0.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z1.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z2.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z3.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z4.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z5.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z6.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z7.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z8.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z9.ome.tiff
/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z10.ome.tiff
/media/local-admin/galaxy/lightmycells_sto

In [711]:
print(f"Image files: {len(img_files)}")
print(f"Nuc files: {len(ch1_files)}")
print(f"Mito files: {len(ch2_files)}")
print(f"Actin files: {len(ch3_files)}")
print(f"Tub files: {len(ch4_files)}")

Image files: 125
Nuc files: 10
Mito files: 10
Actin files: 0
Tub files: 0


### Identify missing targets from channels

In [712]:
channel_lists = [ch1_files, ch2_files, ch3_files, ch4_files]  

In [713]:
all_channels = []

for channel in channel_lists:
    all_channels.extend(channel)

In [714]:
all_channels

['/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_44_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_45_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_46_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_47_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_48_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_49_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_50_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_51_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_52_Nucleus.ome.tiff',
 '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_Mitocho

In [715]:
# Example list of image filenames
target_filenames = all_channels

# Regular expression pattern to match 'image_xy'
pattern = r"image_\d+"

# Extract unique 'image_xy' identifiers
unique_target_identifiers = natsorted(set(re.findall(pattern, ' '.join(target_filenames))))

max_target_count = len(unique_target_identifiers)

# Print unique 'image_xy' identifiers
print("Unique target identifiers:")
for identifier in unique_target_identifiers:
    print(identifier)
print(max_target_count)

Unique target identifiers:
image_43
image_44
image_45
image_46
image_47
image_48
image_49
image_50
image_51
image_52
10


In [716]:
channel_names = ['Nuc', 'Mito', 'Actin', 'Tub']
missing_index_dict = {}

for i,channel in enumerate(channel_lists):
    if len(channel) < max_target_count:
        # Example list of image filenames
        image_filenames = img_files
    
        # Example list of target filenames
        target_filenames = channel
        
        # Regular expression pattern to match 'image_xy'
        pattern = r"image_\d+"
        
        # Extract unique 'image_xy' identifiers from image filenames
        unique_image_identifiers = set(re.findall(pattern, ' '.join(image_filenames)))
        
        # Extract 'image_xy' identifiers from target filenames
        target_image_identifiers = set(re.findall(pattern, ' '.join(target_filenames)))
        
        # List to store identifiers not present in target filenames
        missing_identifiers = []
        
        # Check if unique 'image_xy' identifiers are present in target filenames
        for identifier in unique_image_identifiers:
            if identifier not in target_image_identifiers:
                missing_identifiers.append(identifier)
        
        print(f"Identifiers not present in target filenames for {channel_names[i]}:", missing_identifiers)
        print(f"Number of missing items: {len(missing_identifiers)}")
    
        missing_indexes = []
        #for identifier in missing_identifiers:
        for item in missing_identifiers:
            index_of_missing = unique_target_identifiers.index(item)
            missing_indexes.append(index_of_missing)
            print(f'Index of missing item: {index_of_missing}')
        missing_index_dict[channel_names[i]] = sorted(missing_indexes)

Identifiers not present in target filenames for Actin: ['image_52', 'image_51', 'image_48', 'image_49', 'image_43', 'image_45', 'image_44', 'image_46', 'image_50', 'image_47']
Number of missing items: 10
Index of missing item: 9
Index of missing item: 8
Index of missing item: 5
Index of missing item: 6
Index of missing item: 0
Index of missing item: 2
Index of missing item: 1
Index of missing item: 3
Index of missing item: 7
Index of missing item: 4
Identifiers not present in target filenames for Tub: ['image_52', 'image_51', 'image_48', 'image_49', 'image_43', 'image_45', 'image_44', 'image_46', 'image_50', 'image_47']
Number of missing items: 10
Index of missing item: 9
Index of missing item: 8
Index of missing item: 5
Index of missing item: 6
Index of missing item: 0
Index of missing item: 2
Index of missing item: 1
Index of missing item: 3
Index of missing item: 7
Index of missing item: 4


In [717]:
missing_index_dict

{'Actin': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 'Tub': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

### Collecting images

In [718]:
#for file in img_files:
    #img = AICSImage(file)
    #img_array = img.data
    #print(img_array.shape)

In [719]:
#unique_shapes = set()

#for file in img_files:
#    img = AICSImage(file)
#    img_array = img.data
#    shape = img_array.shape
#    unique_shapes.add(shape)

## Print the unique shapes
#for shape in unique_shapes:
    #print(shape)

In [720]:
# Dictionary to store filenames organized by image number
image_number_dict = {}

# Regular expression pattern to extract image numbers
pattern = r'image_(\d+)_\w+_z\d+\.ome\.tiff'

for filename in img_files:
    # Extract image number from filename
    match = re.search(pattern, filename)
    if match:
        image_number = int(match.group(1))
        # Add filename to the corresponding list based on image number
        if image_number in image_number_dict:
            image_number_dict[image_number].append(filename)
        else:
            image_number_dict[image_number] = [filename]

In [721]:
image_number_dict

{43: ['/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z0.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z1.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z2.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z3.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z4.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z5.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z6.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z7.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z8.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z9.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DI

In [722]:
# Regular expression pattern to extract image numbers
#pattern = r'image_(\d+)_\w+.ome\.tiff'

#for channel in channel_lists:
#    for filename in channel:
#        # Extract image number from filename
#        match = re.search(pattern, filename)
#        if match:
#            image_number = int(match.group(1))
#            # Add filename to the corresponding list based on image number
#            if image_number in image_number_dict:
#                image_number_dict[image_number].append(filename)
#            else:
#                image_number_dict[image_number] = [filename]

In [723]:
# Iterate through the dictionary and print the number of items for each key
num_layers = []
for key, value in image_number_dict.items():
    num_layers.append(len(value))
    print(f"Number of items assigned to key {key}: {len(value)}")

Number of items assigned to key 43: 12
Number of items assigned to key 44: 11
Number of items assigned to key 45: 9
Number of items assigned to key 46: 14
Number of items assigned to key 47: 12
Number of items assigned to key 48: 17
Number of items assigned to key 49: 15
Number of items assigned to key 50: 13
Number of items assigned to key 51: 12
Number of items assigned to key 52: 10


In [724]:
image_number_dict

{43: ['/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z0.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z1.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z2.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z3.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z4.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z5.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z6.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z7.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z8.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DIC_z9.ome.tiff',
  '/media/local-admin/galaxy/lightmycells_storage/images/Study_2/image_43_DI

In [725]:
img_data = []

for image_number, filenames in image_number_dict.items():
    
    # Iterate through the list of filenames for the current image number
    for filename in filenames:
        # Load image data using AICSImage
        img = AICSImage(filename)
        img_array = img.data
        #img_array = img_as_float32(img_array)
        img_array = np.squeeze(img_array)
        #img_array = cv2.resize(img_array, dsize=(resize_size[1], resize_size[0]), interpolation=cv2.INTER_CUBIC)
        #img_array = min_max_normalize(img_array)
        #img_array = z_score_normalize(img_array)
        #img_array = percentile_normalization(img_array)     !!!
        img_array = np.squeeze(img_array)
        # Append the image array to the list of image arrays for the current image number
        img_data.append(img_array)

In [726]:
img_array = np.array(img_data)
img_array = np.expand_dims(img_array, axis=1)

In [727]:
img_array.shape

(125, 1, 1024, 1024)

In [728]:
img_array.dtype

dtype('uint16')

In [729]:
len(img_array)

125

### Adding blank targets if they are missing

In [730]:
mask_stacks = []                                               #This is so if targets are missing, blanks are added in their place

for i,channel in enumerate(channel_lists):
    
    channel_list = []
    for mask in channel:
        mask_img = AICSImage(mask)
        mask_array = mask_img.data
        #mask_array = img_as_float32(mask_array)
        print(mask_array.dtype)
        #mask_array = np.squeeze(img_array)
        mask_array = np.squeeze(mask_array)
        #mask_array = cv2.resize(mask_array, dsize=resize_size)
        #mask_array = min_max_normalize(mask_array)                                     #change to float32, 
        #mask_array = z_score_normalize(mask_array)
        #mask_array = percentile_normalization(mask_array)
        channel_list.append(mask_array) 

    if channel_names[i] in missing_index_dict:
        #print(channel_names[i])
        # Get the indexes from the dictionary corresponding to the appropriate key
        zeros_array = np.zeros(img_array[0][0].shape).astype(np.uint16)                                         #!!!!!!!!!!!!!!!!!!!!!
        mito_indexes = missing_index_dict.get(channel_names[i], [])
        #print(mito_indexes)

        # Insert zeros_array at the indexes specified by mito_indexes in ch1_list
        for idx in mito_indexes:
            channel_list.insert(idx, zeros_array)
        
    mask_stacks.append(channel_list)

uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16


In [731]:
missing_index_dict

{'Actin': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 'Tub': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

In [732]:
for channel in mask_stacks:
    for i,x in enumerate(channel):
        if x.all() == 0:
            print(f'{i}: Replaced with blank')
    print(' ')

0: Replaced with blank
1: Replaced with blank
2: Replaced with blank
3: Replaced with blank
4: Replaced with blank
5: Replaced with blank
6: Replaced with blank
7: Replaced with blank
8: Replaced with blank
9: Replaced with blank
 
0: Replaced with blank
1: Replaced with blank
2: Replaced with blank
3: Replaced with blank
4: Replaced with blank
5: Replaced with blank
6: Replaced with blank
7: Replaced with blank
8: Replaced with blank
9: Replaced with blank
 
0: Replaced with blank
1: Replaced with blank
2: Replaced with blank
3: Replaced with blank
4: Replaced with blank
5: Replaced with blank
6: Replaced with blank
7: Replaced with blank
8: Replaced with blank
9: Replaced with blank
 
0: Replaced with blank
1: Replaced with blank
2: Replaced with blank
3: Replaced with blank
4: Replaced with blank
5: Replaced with blank
6: Replaced with blank
7: Replaced with blank
8: Replaced with blank
9: Replaced with blank
 


#### Replicating masks for each image

In [733]:
mask_data = np.array(mask_stacks)

In [734]:
mask_data.dtype

dtype('uint16')

In [735]:
mask_data = np.swapaxes(mask_data, 0,1)

In [736]:
mask_data.dtype

dtype('uint16')

In [737]:
mask_data.shape

(10, 4, 1024, 1024)

In [738]:
num_layers

[12, 11, 9, 14, 12, 17, 15, 13, 12, 10]

In [739]:
# Repeat each element in the first dimension of mask_data based on num_layers
extended_mask_data = np.repeat(mask_data, num_layers, axis=0)

# Verify the shape of the extended array
print("Shape of extended mask_data:", extended_mask_data.shape)

Shape of extended mask_data: (125, 4, 1024, 1024)


In [740]:
extended_mask_data.dtype

dtype('uint16')

In [741]:
mask_data = extended_mask_data

In [742]:
img_data = img_array

In [743]:
img_data[0].shape

(1, 1024, 1024)

In [744]:
mask_data[0].shape

(4, 1024, 1024)

In [745]:
mask_data.shape

(125, 4, 1024, 1024)

In [746]:
mask_data.dtype

dtype('uint16')

In [747]:
img_data.dtype

dtype('uint16')

### Shuffle the "stacks" so it's not always the last x% that gets split off (while respecting the stacks)

In [748]:
# Initialize lists to store the split data
split_img_data = []
split_mask_data = []

# Iterate over num_layers to split img_data and mask_data
start_idx = 0
for num in num_layers:
    end_idx = start_idx + num
    split_img_data.append(img_data[start_idx:end_idx])
    split_mask_data.append(mask_data[start_idx:end_idx])
    start_idx = end_idx

# Verify the split data
print("Data split according to stacks before shuffling:")
for i, (img, mask) in enumerate(zip(split_img_data, split_mask_data)):
    print(f"Stack {i+1}: Image data shape: {img.shape}, Mask data shape: {mask.shape}, Num layers: {num_layers[i]}")

 #Shuffle the data
shuffled_indices = np.random.permutation(len(split_img_data))
shuffled_img_data = [split_img_data[i] for i in shuffled_indices]
shuffled_mask_data = [split_mask_data[i] for i in shuffled_indices]
shuffled_num_layers = [num_layers[i] for i in shuffled_indices]

# Verify the shuffled data
print("Data split according to stacks after shuffling:")
for i, (img, mask) in enumerate(zip(shuffled_img_data, shuffled_mask_data)):
    print(f"Stack {i+1}: Image data shape: {img.shape}, Mask data shape: {mask.shape}, Num layers: {shuffled_num_layers[i]}")

# Reunite the shuffled data into a single array
reunited_shuffled_img_data = np.concatenate(shuffled_img_data, axis=0)
reunited_shuffled_mask_data = np.concatenate(shuffled_mask_data, axis=0)

# Verify the shape of the reunited shuffled data
print("Reunited Shuffled Image Data Shape:", reunited_shuffled_img_data.shape)
print("Reunited Shuffled Mask Data Shape:", reunited_shuffled_mask_data.shape)

img_data = reunited_shuffled_img_data
mask_data = reunited_shuffled_mask_data
num_layers = shuffled_num_layers

Data split according to stacks before shuffling:
Stack 1: Image data shape: (12, 1, 1024, 1024), Mask data shape: (12, 4, 1024, 1024), Num layers: 12
Stack 2: Image data shape: (11, 1, 1024, 1024), Mask data shape: (11, 4, 1024, 1024), Num layers: 11
Stack 3: Image data shape: (9, 1, 1024, 1024), Mask data shape: (9, 4, 1024, 1024), Num layers: 9
Stack 4: Image data shape: (14, 1, 1024, 1024), Mask data shape: (14, 4, 1024, 1024), Num layers: 14
Stack 5: Image data shape: (12, 1, 1024, 1024), Mask data shape: (12, 4, 1024, 1024), Num layers: 12
Stack 6: Image data shape: (17, 1, 1024, 1024), Mask data shape: (17, 4, 1024, 1024), Num layers: 17
Stack 7: Image data shape: (15, 1, 1024, 1024), Mask data shape: (15, 4, 1024, 1024), Num layers: 15
Stack 8: Image data shape: (13, 1, 1024, 1024), Mask data shape: (13, 4, 1024, 1024), Num layers: 13
Stack 9: Image data shape: (12, 1, 1024, 1024), Mask data shape: (12, 4, 1024, 1024), Num layers: 12
Stack 10: Image data shape: (10, 1, 1024, 102

In [749]:
img_data.dtype

dtype('uint16')

In [750]:
mask_data.dtype

dtype('uint16')

### Split into train, valid, test

#### Split train and test_t

In [751]:
# Example data
num_stacks = len(num_layers)

# Accumulate the count of images in each stack
stack_indices = np.cumsum(num_layers)

# Total number of images
total_images = sum(num_layers)

# Calculate the split indices
split_index = int(total_images * 0.9)  # Assuming 80% training and 20% testing, change this if you have more data

# Find the stack index at which the split occurs
stack_index = np.argmax(stack_indices >= split_index)

if ((num_stacks-1)-stack_index) < 1:                 # So there is at least 1 stack separated (which also means that the test data will be empty. Should this be 2 instead? But then ratio of data could be very different. Would also rather have empty test than validation              
    stack_index = stack_index-1
    print(stack_index)

index = stack_indices[stack_index]   #???

# Split the data into training and testing sets
train_data = img_data[:index]
train_targets = mask_data[:index]
test_data_t = img_data[index:]
test_targets_t = mask_data[index:]

# Print the shapes of the resulting datasets
print("Num stacks:", num_stacks)
print("Stack indices:", stack_indices)
print("Split index:", split_index)
print("Stack index:", stack_index)
print("Index:", index)
print("Training Image Data Shape:", train_data.shape)
print("Training Image Data Shape:", train_data.shape)
print("Training Mask Data Shape:", train_targets.shape)
print("Testing Image Data Shape:", test_data_t.shape)
print("Testing Mask Data Shape:", test_targets_t.shape)

8
Num stacks: 10
Stack indices: [ 12  24  33  46  57  72  84  94 108 125]
Split index: 112
Stack index: 8
Index: 108
Training Image Data Shape: (108, 1, 1024, 1024)
Training Image Data Shape: (108, 1, 1024, 1024)
Training Mask Data Shape: (108, 4, 1024, 1024)
Testing Image Data Shape: (17, 1, 1024, 1024)
Testing Mask Data Shape: (17, 4, 1024, 1024)


In [752]:
## Iterate through each image in train_targets
#for train_image_idx, train_image in enumerate(train_targets):
#    # Iterate through each image in test_targets_t
#    for test_image_idx, test_image in enumerate(test_targets_t):
#        # Check if the current train image is identical to the current test image
#        if np.array_equal(train_image, test_image):
#            print(f"Duplicate found! Train image index: {train_image_idx}, Test image index: {test_image_idx}")

In [753]:
train_data.dtype

dtype('uint16')

In [754]:
train_targets.dtype

dtype('uint16')

In [755]:
test_data_t.dtype

dtype('uint16')

In [756]:
test_targets_t.dtype

dtype('uint16')

#### Split into test and valid

In [757]:
num_layers_test = num_layers[stack_index+1:]

In [758]:
num_layers_test

[17]

In [759]:
num_stacks_test = len(num_layers_test)

# Accumulate the count of images in each stack
stack_indices_test = np.cumsum(num_layers_test)

# Total number of images
total_images_test = sum(num_layers_test)

# Calculate the split indices
split_index_test = int(total_images_test * 0.5)  # Assuming 80% training and 20% testing

# Find the stack index at which the split occurs
stack_index_test = np.argmax(stack_indices_test >= split_index_test)
    
index_test = stack_indices_test[stack_index_test]
print(index_test)

# Split the data into valid and testing sets
validation_data = test_data_t[:index_test]
validation_targets = test_targets_t[:index_test]
test_data = test_data_t[index_test:]
test_targets = test_targets_t[index_test:]

# Print the shapes of the resulting datasets
print("Valid Image Data Shape:", validation_data.shape)
print("Valid Mask Data Shape:", validation_targets.shape)
print("Test Image Data Shape:", test_data.shape)
print("Test Mask Data Shape:", test_targets.shape)

17
Valid Image Data Shape: (17, 1, 1024, 1024)
Valid Mask Data Shape: (17, 4, 1024, 1024)
Test Image Data Shape: (0, 1, 1024, 1024)
Test Mask Data Shape: (0, 4, 1024, 1024)


In [760]:
## Iterate through each image in train_targets
#for train_image_idx, train_image in enumerate(train_targets):
#    # Iterate through each image in test_targets_t
#    for test_image_idx, test_image in enumerate(test_targets):
#        # Check if the current train image is identical to the current test image
#        if np.array_equal(train_image, test_image):
#            print(f"Duplicate found! Train image index: {train_image_idx}, Test image index: {test_image_idx}")

In [761]:
validation_data.dtype

dtype('uint16')

In [762]:
validation_targets.dtype

dtype('uint16')

In [763]:
test_data.dtype

dtype('uint16')

In [764]:
test_targets.dtype

dtype('uint16')

Alright so I suppose depending on how many images a stack has it might not split it evenly exactly?

#### Old splitting (irrelevant)

##### Test set

In [765]:
#from sklearn.model_selection import train_test_split

In [766]:
#train_data_t, test_data, train_targets_t, test_targets = train_test_split(img_data, mask_data, test_size=0.05)

In [767]:
#train_data_t.shape

In [768]:
#train_targets_t.shape

In [769]:
#test_data.shape

In [770]:
#test_targets.shape

##### Validation set

In [771]:
#train_data, validation_data, train_targets, validation_targets = train_test_split(train_data_t, train_targets_t, test_size=0.05) #not really 5%, less

In [772]:
#train_data.shape

In [773]:
#train_targets.shape

In [774]:
#validation_data.shape

In [775]:
#validation_targets.shape

In [776]:
#test_data.shape

In [777]:
#test_targets.shape

### Data augmentation

In [778]:
def augment_images(train_images, target_images, augmentation_percentage=0.3):
    augmented_train_images = []
    augmented_target_images = []
    num_images_to_augment = int(len(train_images) * augmentation_percentage)
    indices_to_augment = np.random.choice(len(train_images), size=num_images_to_augment, replace=False)
    
    for idx, (train_img, target_img) in enumerate(zip(train_images, target_images)):
        if idx in indices_to_augment:
            augmented_train_img = train_img.copy()
            augmented_target_img = target_img.copy()
            random_var = np.random.rand()
            noise = np.random.normal(loc=0, scale=1, size=augmented_train_img[0].shape).astype(np.uint16)
            print(type(noise))
            print(augmented_train_img.dtype)
            print(augmented_target_img.dtype)
            
            for layer in range(augmented_train_img.shape[0]):                             
                # Apply flipping & noise
                if random_var <= 0.2:
                    augmented_train_img[layer] = np.flipud(augmented_train_img[layer])      # All of them will never happen together, right?
                elif 0.2 <= random_var <= 0.8:
                   augmented_train_img[layer] += noise         # Add Gaussian noise
                else:
                    augmented_train_img[layer] = np.fliplr(augmented_train_img[layer])       

            augmented_train_images.append(augmented_train_img)
            
            for layer in range(augmented_target_img.shape[0]):
                if np.all(augmented_target_img[layer] == 0):
                    continue
                    
                elif random_var <= 0.2:
                    augmented_target_img[layer] = np.flipud(augmented_target_img[layer])
                elif 0.2 <= random_var <= 0.8:
                    augmented_target_img[layer] += noise  # Add Gaussian noise
                else:
                    augmented_target_img[layer] = np.fliplr(augmented_target_img[layer])
                 
            augmented_target_images.append(augmented_target_img)

            print(augmented_train_img.dtype)
            print(augmented_target_img.dtype)

        
        
        augmented_train_images.append(train_img)          # keep the originals too
        augmented_target_images.append(target_img)
        print(train_img.dtype)
        print(target_img.dtype)
        
        
    
    return np.array(augmented_train_images), np.array(augmented_target_images)

In [779]:
train_data[0].shape

(1, 1024, 1024)

In [780]:
augmented_train_data, augmented_target_data = augment_images(train_data, train_targets, augmentation_percentage=0.25)

uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
<class 'numpy.ndarray'>
uint16
uint16
uint16
uint16


In [781]:
print(augmented_train_data.shape)
print(augmented_target_data.shape)

(135, 1, 1024, 1024)
(135, 4, 1024, 1024)


In [782]:
train_data = augmented_train_data
train_targets = augmented_target_data

In [783]:
print(train_data.dtype)
print(train_targets.dtype)

uint16
uint16


### Taking random sections (irrelevant, moved to training every few epochs)

In [784]:
#import numpy as np

In [785]:
#def extract_random_section(image, mask, section_size=(256, 256)):
#    # Get the shape of the image
#    image_height, image_width = image.shape[1:3]
    
#    # Calculate the maximum valid starting position
#    max_starting_height = image_height - section_size[0]
#    max_starting_width = image_width - section_size[1]
    
#    # Randomly select starting positions within the valid range
#    starting_height = np.random.randint(0, max_starting_height + 1)
#    starting_width = np.random.randint(0, max_starting_width + 1)
    
#    # Extract the section
#    section_image = image[:,starting_height:starting_height + section_size[0],
#                    starting_width:starting_width + section_size[1]]
#    section_mask = mask[:, starting_height:starting_height + section_size[0],
#                    starting_width:starting_width + section_size[1]]
    
#    return (section_image, section_mask)

# Example usage:
# Assuming 'image' is your input image array
# Randomly extract a 256x256 section from 'image'
#random_section1, random_mask1 = extract_random_section(img_data[0], mask_data[0])

In [786]:
#i_d = []
#m_d = []

In [787]:
#for img, mask in zip(train_data_t, train_targets_t):
#    for x in range(0, 20):
#        i, m = extract_random_section(img, mask)
#        i_d.append(i)
#        m_d.append(m)

In [788]:
#len(i_d)

In [789]:
#i_d[0].shape

In [790]:
#train_data_t = np.array([])  # Reassign an empty NumPy array
#train_targets_t = np.array([])

In [791]:
#train_data_t = np.array(i_d)
#train_targets_t = np.array(m_d)

In [792]:
#train_data_t.shape

In [793]:
#train_targets_t.shape

### Pytable

##### Create pytable (run only once)

See create_pytables.ipynb

### Adding data to pytable

In [794]:
import tables
import numpy as np

In [795]:
def batch_write_to_pytables(data_list, table, batch_size=1000):
    total_records = len(data_list)
    for i in range(0, total_records, batch_size):
        batch_data = data_list[i:i+batch_size]
        table.append(batch_data)
        table.flush()

#### Add train dataset

In [796]:
# Open an existing PyTable file or create a new one
with tables.open_file('/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_train.pytable', mode='a') as h5file:

    # Access the desired group or table
    img_earray = h5file.root.img.data
    mask_earray = h5file.root.mask.data

    patches = []
    for record in train_data:
        for patch in tile_image(record):
            patches.append(np.squeeze(patch))                     # Is the order respected here?
            print(patch.dtype)
    batch_write_to_pytables(patches, img_earray)

    patches = []
    for record in train_targets:
        for patch in tile_image(record):
            patches.append(patch)
            print(patch.dtype)
    batch_write_to_pytables(patches, mask_earray)    

uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16

#### Add validation dataset

In [797]:
# Open an existing PyTable file or create a new one
with tables.open_file('/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_valid.pytable', mode='a') as h5file:

     # Access the desired group or table

    data_shape = validation_data.shape
    shape_suffix = f"{data_shape[-2]}"  # Assuming the shape you're interested in is in the second-to-last dimension
    
    # Access the desired group or table
    img_earray = h5file.root.img[f'data_{shape_suffix}']
    mask_earray = h5file.root.mask[f'data_{shape_suffix}']
    
    # Iterate through your data and append it to the EArray
    for data_row in validation_data:
        data_row = np.expand_dims(data_row, axis=0)
        img_earray.append(data_row)
        print(data_row.dtype)

    for data_row in validation_targets:
        data_row = np.expand_dims(data_row, axis=0)
        mask_earray.append(data_row)
        print(data_row.dtype)

uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16


#### Add test dataset

In [798]:
# Open an existing PyTable file or create a new one
with tables.open_file('/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_test.pytable', mode='a') as h5file:

    # Access the desired group or table

    data_shape = test_data.shape
    shape_suffix = f"{data_shape[-2]}"  # Assuming the shape you're interested in is in the second-to-last dimension
    
    # Access the desired group or table
    img_earray = h5file.root.img[f'data_{shape_suffix}']
    mask_earray = h5file.root.mask[f'data_{shape_suffix}']
    
    # Iterate through your data and append it to the EArray
    for data_row in test_data:
        data_row = np.expand_dims(data_row, axis=0)
        img_earray.append(data_row)
        print(data_row.dtype)

    for data_row in test_targets:
        data_row = np.expand_dims(data_row, axis=0)
        mask_earray.append(data_row)
        print(data_row.dtype)

#### Check pytable structures

In [799]:
train_file = '/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_train.pytable'
test_file = '/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_test.pytable'
valid_file = '/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_valid.pytable'

# Open the PyTable file
with tables.open_file(train_file, mode='r') as h5file:
    # Print the structure of the file
    print(h5file)

    # Access and print information about groups
    for group in h5file.walk_groups():
        print(f"Group: {group}")

    # Access and print information about arrays
    for array in h5file.walk_nodes(where='/', classname='Array'):
        print(f"Array: {array}")

/media/local-admin/lightmycells/Study_patches_s_3/tools/pytables/study_patches_d_train.pytable (File) ''
Last modif.: '2024-04-15T12:39:30+00:00'
Object Tree: 
/ (RootGroup) ''
/img (Group) 'Image data'
/img/data (EArray(3375, 256, 256)zlib(9)) ''
/mask (Group) 'Mask data'
/mask/data (EArray(3375, 4, 256, 256)zlib(9)) ''

Group: / (RootGroup) ''
Group: /img (Group) 'Image data'
Group: /mask (Group) 'Mask data'
Array: /img/data (EArray(3375, 256, 256)zlib(9)) ''
Array: /mask/data (EArray(3375, 4, 256, 256)zlib(9)) ''


In [800]:
train_targets.shape

(135, 4, 1024, 1024)

In [801]:
train_data.shape

(135, 1, 1024, 1024)

In [802]:
tables.file._open_files.close_all()

#### Old pytable (irrelevant)

In [803]:
# Define the shape of your image and mask data
#img_shape = train_data[0].shape 
#mask_shape = train_targets[0].shape

#numpixels
#pixel_list = []
#for target in train_targets:
#     pixel_list.append(np.count_nonzero(target))
#pixel_array = np.array(pixel_list)
#pixel_array = np.reshape(pixel_array, (len(train_targets), 1))

## Verify the shapes of the loaded data
#print("Image data shape:", train_data.shape)
#print("Mask data shape:", train_data.shape)

##data_shape
#data_shape = train_data.shape[0]

#pytable_fname = '/media/local-admin/galaxy/lightmycells_storage/Study_patches_s/tools/pytables/train/study_patches_s_Study_24_train.pytable'
## Create PyTable instance
#custom_pytable = CustomPyTable(pytable_fname, img_shape, mask_shape)

## Create PyTable file and populate with image and mask data
#custom_pytable.create_pytable(train_data, train_targets, pixel_array, data_shape)

In [804]:
#img_shape = validation_data[0].shape 
#mask_shape = validation_targets[0].shape

#numpixels
#pixel_list = []
#for target in validation_targets:
#     pixel_list.append(np.count_nonzero(target))
#pixel_array = np.array(pixel_list)
#pixel_array = np.reshape(pixel_array, (len(validation_targets), 1))

## Verify the shapes of the loaded data
#print("Image data shape:", validation_data.shape)
#print("Mask data shape:", validation_data.shape)

##data_shape
#data_shape = validation_data.shape

#pytable_fname = '/media/local-admin/galaxy/lightmycells_storage/Study_patches_s/tools/pytables/valid/study_patches_s_Study_24_valid.pytable'
## Create PyTable instance
#custom_pytable = CustomPyTable(pytable_fname, img_shape, mask_shape)

## Create PyTable file and populate with image and mask data
#custom_pytable.create_pytable(validation_data, validation_targets, pixel_array, data_shape)

In [805]:
#img_shape = test_data[0].shape 
#mask_shape = test_targets[0].shape

##numpixels
#pixel_list = []
#for target in test_targets:
#     pixel_list.append(np.count_nonzero(target))
#pixel_array = np.array(pixel_list)
#pixel_array = np.reshape(pixel_array, (len(test_targets), 1))

## Verify the shapes of the loaded data
#print("Image data shape:", test_data.shape)
#print("Mask data shape:", test_data.shape)

##data_shape
#data_shape = test_data.shape

#pytable_fname = '/media/local-admin/galaxy/lightmycells_storage/Study_patches_s/tools/pytables/test/study_patches_s_Study_24_test.pytable'
## Create PyTable instance
#custom_pytable = CustomPyTable(pytable_fname, img_shape, mask_shape)

## Create PyTable file and populate with image and mask data
#custom_pytable.create_pytable(test_data, test_targets, pixel_array, data_shape)