In [1]:
from tiff_processing_utils import tile_tiff_pair, rasterize_shapefile
from MemoryMapDataset import MemmapDataset
import numpy as np
import psutil
import os
import gc

The processing pipeline assumes that the data in the Chunks folder is in the following format:
- Each chunk is in its own folder, named 'Chunk x' or 'Chunk x x-x'.
- The RGB tif should be named 'Chunkx.tif' or 'Chunkx_x-x.tif'
- The label shapefile and its accompanying files should be in a folder called 'labels' inside the matching 'Chunk x' / 'Chunk x x-x' folder; these files may have any name.

In [3]:
# Folder containing the 'Chunk ...' directories; the combined datasets are written here too.
# Set the CHUNKS_BASE_PATH environment variable instead of editing this user-specific default.
base_path = os.environ.get("CHUNKS_BASE_PATH", "C:/Users/gwrye/OneDrive/Desktop")

TILE_SIZE = 256  # passed to tile_tiff_pair as image_size; also prefixes the output file names

combined_images_file = os.path.join(base_path, f'{TILE_SIZE}dataset_images.npy')
combined_labels_file = os.path.join(base_path, f'{TILE_SIZE}dataset_labels.npy')

# RAM thresholds (MB)
TOTAL_RAM_MB = psutil.virtual_memory().total / (1024 ** 2)
# 16GB below total RAM. NOTE: on machines with <= 16GB this is <= 0, so every
# buffered chunk is flushed to disk immediately.
SAFE_RAM_USAGE_MB = TOTAL_RAM_MB - 16 * 1024
CHUNK_BUFFER_SIZE = 15  # Number of chunks to keep in memory at a time

In [None]:
# Convert every chunk's label shapefile to a tif so it can be tiled alongside the RGB image.
for entry in sorted(os.listdir(base_path)):
    chunk_path = os.path.join(base_path, entry)
    # Only chunk *directories* are processed; a stray file whose name contains
    # 'Chunk' would otherwise be handed to rasterize_shapefile.
    if 'Chunk' in entry and os.path.isdir(chunk_path):
        rasterize_shapefile(chunk_path)
print('\nDone rasterizing shapefiles')


In [4]:
# Helpers that report this process's resident memory (RSS) in megabytes
def get_memory_usage():
    """Return the current process's resident set size, in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)

def print_memory_usage():
    """Print the current process's resident set size, in MB, to two decimals."""
    print(f"Memory Usage: {get_memory_usage():.2f} MB")

In [None]:
"""
Convert all tif pairs into tiled datasets

NOTE: This will take a lot of time, memory, and storage space.
You should have at least 32GB of RAM and triple the chunk folder size of storage. If you don't have enough RAM,
you can run this script in smaller chunks by lowering the CHUNK_BUFFER_SIZE variable.
"""

# Function to append data to memory-mapped file
def append_to_memmap(file_path, data, dtype):
    """Append ``data`` along axis 0 to the ``.npy`` file at ``file_path``.

    If the file does not exist it is created with ``data``'s shape. Otherwise a
    larger copy holding the old rows followed by ``data`` is written to
    ``file_path + '.tmp'`` and moved over the original with ``os.replace``, so
    every call rewrites the whole existing file.

    Args:
        file_path: Path of the ``.npy`` file to create or grow.
        data: Array-like whose rows (axis 0) are appended.
        dtype: dtype of the created / expanded file; values are cast to it on
            assignment.

    Raises:
        ValueError: If ``data.shape[1:]`` differs from the stored array's
            trailing dimensions (instead of silently broadcasting).
    """
    data = np.asarray(data)
    if not os.path.exists(file_path):
        print(f"Creating new memmap file at {file_path}")
        new_memmap = np.lib.format.open_memmap(file_path, mode='w+', dtype=dtype, shape=data.shape)
        new_memmap[:] = data
        new_memmap.flush()
        return

    # Read-only is enough: the existing file is only copied, never modified in place.
    existing = np.load(file_path, mmap_mode='r')
    if existing.shape[1:] != data.shape[1:]:
        raise ValueError(
            f"Cannot append data with trailing shape {data.shape[1:]} to "
            f"{file_path} with trailing shape {existing.shape[1:]}"
        )
    if data.shape[0] == 0:
        return  # Nothing to add; avoid rewriting the whole file.

    old_rows = existing.shape[0]
    new_shape = (old_rows + data.shape[0],) + existing.shape[1:]

    # Build the expanded array in a temporary file next to the original
    temp_file_path = file_path + '.tmp'
    expanded = np.lib.format.open_memmap(temp_file_path, mode='w+', dtype=dtype, shape=new_shape)
    expanded[:old_rows] = existing
    expanded[old_rows:] = data
    expanded.flush()

    # Release BOTH mappings before os.replace: on Windows a file that is still
    # open / memory-mapped cannot be renamed or overwritten.
    del existing, expanded

    # Replace the original file with the temporary file
    os.replace(temp_file_path, file_path)

# Buffer for storing data before appending to memmap
image_buffer = []
label_buffer = []


def _flush_buffers(images_buf, labels_buf):
    """Append the buffered chunk arrays to the combined .npy files, then empty both buffers."""
    images_to_append = np.concatenate(images_buf, axis=0)
    labels_to_append = np.concatenate(labels_buf, axis=0)
    append_to_memmap(combined_images_file, images_to_append, np.uint8)
    append_to_memmap(combined_labels_file, labels_to_append, np.uint8)

    # Drop every reference to the flushed data (buffers are cleared in place so
    # the caller's lists are emptied) before reporting memory.
    del images_to_append, labels_to_append
    images_buf.clear()
    labels_buf.clear()
    gc.collect()
    print_memory_usage()


# append_to_memmap grows existing files, so outputs left by a previous (or
# interrupted) run would be duplicated into this dataset. Start from scratch.
for output_path in (combined_images_file, combined_labels_file):
    for stale_path in (output_path, output_path + '.tmp'):
        if os.path.exists(stale_path):
            print(f"Removing existing {stale_path}")
            os.remove(stale_path)

# Only directories count as chunks; sorting makes the tile order reproducible.
chunk_entries = sorted(
    entry for entry in os.listdir(base_path)
    if 'Chunk' in entry and os.path.isdir(os.path.join(base_path, entry))
)
num_chunks = len(chunk_entries)
print(f"Processing {num_chunks} chunk directories")

# Iterate over each chunk directory and process TIFF pairs
for current_chunk, entry in enumerate(chunk_entries, start=1):
    print(f"\nChunk {current_chunk}/{num_chunks}")
    chunk_path = os.path.join(base_path, entry)

    # Generate tiled images and labels
    chunk_images, chunk_labels = tile_tiff_pair(chunk_path, image_size=TILE_SIZE)
    # Guard both "no tiles at all" (indexing [0] would raise) and "empty first tile"
    if len(chunk_images) == 0 or chunk_images[0].size == 0:
        print(f"No valid tiles found at {entry}")
        continue

    # Add to buffer; the buffers become the only references to this chunk
    image_buffer.append(chunk_images)
    label_buffer.append(chunk_labels)
    del chunk_images, chunk_labels

    # Flush when RAM is tight or enough chunks are buffered. Counting buffered
    # chunks (not chunk indices) keeps skipped chunks from delaying a flush.
    current_memory_usage = get_memory_usage()
    over_memory = current_memory_usage > SAFE_RAM_USAGE_MB
    if over_memory or len(image_buffer) >= CHUNK_BUFFER_SIZE:
        if over_memory:
            print(f"Memory usage {current_memory_usage:.2f} MB exceeds {SAFE_RAM_USAGE_MB} threshold. Appending to memmap.")
        else:
            print("Appending to memmap...")
        _flush_buffers(image_buffer, label_buffer)

# Final append if buffer is not empty
if image_buffer:
    print("Appending remaining buffered data to memmap.")
    _flush_buffers(image_buffer, label_buffer)

print('\nDone tiling tif pairs')


In [5]:
# Shuffle data one entry at a time using Fisher-Yates shuffle
# This is necessary because the data is too large to load into memory all at once
def shuffle_data(images, labels):
    """Shuffle ``images`` and ``labels`` in place with the same permutation.

    Rows are swapped two at a time so only two rows per array are ever held in
    memory, which makes this usable on on-disk memmaps. Random draws come from
    the global ``np.random`` state, so ``np.random.seed`` makes it reproducible.

    Args:
        images: Array (or memmap) indexed by sample along axis 0.
        labels: Array (or memmap) with the same length along axis 0.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """
    dataset_size = images.shape[0]
    if labels.shape[0] != dataset_size:
        raise ValueError(
            f"images ({dataset_size}) and labels ({labels.shape[0]}) have different lengths"
        )

    # Report progress ~100 times instead of on every swap
    report_every = max(1, dataset_size // 100)
    for i in range(dataset_size - 1, 0, -1):
        j = np.random.randint(0, i + 1)
        if j != i:
            # Integer-array indexing on the right-hand side yields COPIES. The
            # tuple form `a[i], a[j] = a[j], a[i]` swaps *views* for arrays with
            # ndim > 1, which copies row j over row i and loses row i.
            images[[i, j]] = images[[j, i]]
            labels[[i, j]] = labels[[j, i]]
        done = dataset_size - i
        if done % report_every == 0:
            print(f"Percent Shuffled: {100*done/dataset_size:.2f}%", end='\r')
    if dataset_size > 1:
        print(f"Percent Shuffled: {100:.2f}%", end='\r')

    # Persist the reordered rows when the inputs are memmaps (plain arrays have no flush)
    for array in (images, labels):
        flush = getattr(array, 'flush', None)
        if flush is not None:
            flush()

# Open both combined datasets as writable memmaps so they are shuffled in place on disk
images, labels = (np.load(npy_path, mmap_mode='r+') for npy_path in (combined_images_file, combined_labels_file))

shuffle_data(images, labels)

Percent Shuffled: 100.00%

In [None]:
# Example: wrap the shuffled memmaps in a dataset and inspect the first sample
dataset = MemmapDataset(images, labels)
dataset_length = len(dataset)
print(f"Dataset length: {dataset_length}")
first_image = dataset.images[0]
print(f"Dataset image shape: {first_image.shape}")
first_label = dataset.labels[0]
print(f"Dataset label shape: {first_label.shape}")

In [None]:
# If labels.tif files are no longer needed, delete them from every chunk directory.
# Safe to re-run: non-directory entries and already-removed files are skipped.
for entry in os.listdir(base_path):
    chunk_path = os.path.join(base_path, entry)
    if 'Chunk' not in entry or not os.path.isdir(chunk_path):
        continue
    labels_tif = os.path.join(chunk_path, 'labels.tif')
    if os.path.exists(labels_tif):
        os.remove(labels_tif)
    else:
        print(f"No labels.tif found in {entry}; skipping")