In [4]:
from tqdm import tqdm
import torch
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import re
from torch.utils.data import Dataset
from scipy.stats import sem
from tabulate import tabulate
import datetime
import matplotlib.colors as mcolors
import random
import pickle
import pandas as pd

In [5]:
def create_sparse_signal(input_image_batch, signal_points=2, linear=False):
    """Return a copy of the batch where each image keeps at most ``signal_points`` non-zero pixels.

    Images that already contain ``signal_points`` or fewer non-zero pixels are
    copied through unchanged. For every other image, ``signal_points`` of its
    non-zero pixels are retained with their original values and all of its
    other non-zero pixels are set to zero. The input tensor is never modified.

    Args:
        input_image_batch: torch tensor whose first dimension indexes the
            images; all remaining dimensions are flattened per image. The
            callers in this file pass [batch_size, 1, x_dim, y_dim].
        signal_points: maximum number of non-zero pixels to keep per image.
            Must be >= 0.
        linear: if True, the kept pixels are evenly spaced (``torch.linspace``)
            through the image's non-zero pixels taken in flattened order, which
            is deterministic. If False, they are chosen with ``torch.randperm``
            (torch's global RNG).

    Returns:
        A new tensor with the same shape as ``input_image_batch``.

    Raises:
        ValueError: if ``signal_points`` is negative.
    """
    if signal_points < 0:
        raise ValueError(f"signal_points must be non-negative, got {signal_points}")

    # Copy into contiguous memory: a plain clone() preserves the strides of a
    # permuted/transposed input, and the .view() below would then raise.
    image_batch = input_image_batch.clone(memory_format=torch.contiguous_format)

    # An empty batch has nothing to sparsify (and view(0, -1) is ambiguous).
    if image_batch.numel() == 0:
        return image_batch

    # One row per image. This is a view, so writes below land in image_batch.
    flat_batch = image_batch.view(image_batch.size(0), -1)

    # Only images with more non-zero pixels than signal_points need thinning.
    nz_counts = torch.sum(flat_batch != 0, dim=1)
    sparse_indices = torch.where(nz_counts > signal_points)[0]

    for idx in sparse_indices:
        # as_tuple=True always yields a 1-D index tensor. The former
        # nonzero(...).squeeze() collapsed to 0-dim when exactly one pixel was
        # non-zero, which made the indexing further down raise.
        nz_indices = torch.nonzero(flat_batch[idx], as_tuple=True)[0]

        # Positions (within nz_indices) of the pixels to keep.
        if linear:
            kept_indices = torch.linspace(0, nz_indices.numel() - 1, steps=signal_points).long()
        else:
            kept_indices = torch.randperm(nz_indices.numel())[:signal_points]

        # Mark every non-zero pixel for removal, then spare the kept ones.
        drop_mask = torch.ones(nz_indices.numel(), dtype=torch.bool, device=nz_indices.device)
        drop_mask[kept_indices] = False
        flat_batch[idx, nz_indices[drop_mask]] = 0

    # Restore the original (batch, ...) shape.
    return flat_batch.view_as(image_batch)

def add_noise_points_to_batch_prenorm(input_image_batch, noise_points=100, time_dimension=100):
    """Return a copy of the batch with ``noise_points`` random noise pixels written into each image.

    For every image in the batch, ``noise_points`` distinct pixel positions are
    drawn without replacement over the (shape[2], shape[3]) grid. The value at
    each chosen position in channel 0 is overwritten (not summed) with a fresh
    draw from ``np.random.uniform(0, time_dimension)``. The input tensor is not
    modified. Randomness comes from NumPy's global RNG.

    Args:
        input_image_batch: 4-D torch tensor indexed as [batch, channel, x, y];
            only channel 0 of each image is written to.
        noise_points: number of noise pixels per image. If not positive, the
            copy is returned unchanged.
        time_dimension: upper bound of the uniform range the noise values are
            drawn from.

    Returns:
        The noised copy, same shape as the input.
    """
    noisy_batch = input_image_batch.clone()

    # Nothing to add: hand back the untouched copy.
    if not noise_points > 0:
        return noisy_batch

    # Spatial extent of each image and the flat ids of all of its pixels.
    n_rows = noisy_batch.shape[2]
    n_cols = noisy_batch.shape[3]
    pixel_ids = np.arange(n_rows * n_cols)

    for batch_idx in range(noisy_batch.shape[0]):
        # Distinct flat positions for this image, converted to 2-D coordinates.
        # np.random.choice raises if noise_points exceeds the pixel count.
        chosen_ids = np.random.choice(pixel_ids, noise_points, replace=False)
        rows, cols = np.unravel_index(chosen_ids, (n_rows, n_cols))

        # Overwrite each chosen pixel with its own uniform draw.
        for row, col in zip(rows, cols):
            noisy_batch[batch_idx][0][row, col] = np.random.uniform(0, time_dimension)

    return noisy_batch

In [6]:
def create_data(path, input_path, signal_points=30, noise_points=100, time_dimension=1000, plot=False):
    """Generate a sparsified, noised copy of every image found in ``<input_path>/Data``.

    Each file in ``<input_path>/Data`` is loaded with ``np.load`` and three
    arrays are written under ``path``, all reusing the source file name:

    * ``<path>/Data``          - image after ``create_sparse_signal`` and
      ``add_noise_points_to_batch_prenorm``
    * ``<path>/Labels``        - the original image, unchanged
    * ``<path>/Labels_Sparse`` - image after ``create_sparse_signal`` only

    The output folders are created if they do not exist. The helpers draw from
    torch's and NumPy's global RNGs, so seed both beforehand for reproducible
    output; files are processed in sorted name order so that a fixed seed
    always pairs the same draws with the same image.

    Args:
        path: root directory the generated dataset is written to.
        input_path: root directory of the source dataset (must contain ``Data``).
        signal_points: forwarded to ``create_sparse_signal``.
        noise_points: forwarded to ``add_noise_points_to_batch_prenorm``.
        time_dimension: forwarded to ``add_noise_points_to_batch_prenorm``.
        plot: if True, show the original and processed image side by side for
            every file.

    Returns:
        None. Results are written to disk.
    """
    # Build paths with os.path.join rather than hard-coded '\\' separators so
    # the function is not tied to Windows.
    source_dir = os.path.join(input_path, 'Data')
    data_dir = os.path.join(path, 'Data')
    labels_dir = os.path.join(path, 'Labels')
    sparse_labels_dir = os.path.join(path, 'Labels_Sparse')

    for out_dir in (labels_dir, sparse_labels_dir, data_dir):
        os.makedirs(out_dir, exist_ok=True)

    # os.listdir returns names in arbitrary order; sort for a stable run order.
    image_names = sorted(os.listdir(source_dir))

    for image_name in tqdm(image_names, desc='Creating Data', unit='image'):
        image = np.load(os.path.join(source_dir, image_name))

        # The helpers expect a leading (batch, channel) pair of dimensions.
        image_tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0)

        sparse_output = create_sparse_signal(image_tensor, signal_points)
        # NOTE: the detector-resolution step is currently disabled:
        # sparse_and_resolution_limited = simulate_detector_resolution(sparse_output_batch, x_std_dev_r, y_std_dev_r, tof_std_dev_r)
        noised_output = add_noise_points_to_batch_prenorm(sparse_output, noise_points, time_dimension)

        # Drop the (batch, channel) dimensions again and convert to NumPy.
        sparse_image = sparse_output.squeeze(0).squeeze(0).numpy()
        noised_image = noised_output.squeeze(0).squeeze(0).numpy()

        if plot:
            # Show source and processed image side by side for comparison.
            fig, ax = plt.subplots(1, 2, figsize=(10, 5))
            ax[0].imshow(image, cmap='gray')
            ax[0].set_title('Original Image')
            ax[1].imshow(noised_image, cmap='gray')
            ax[1].set_title('Noised, Sparse and Resolution Limited Image')
            plt.show()
            # Release the figure; otherwise one figure per file can pile up.
            plt.close(fig)

        # Processed image -> <path>/Data
        np.save(os.path.join(data_dir, image_name), noised_image)
        # Original image -> <path>/Labels
        np.save(os.path.join(labels_dir, image_name), image)
        # Sparse (un-noised) image -> <path>/Labels_Sparse
        np.save(os.path.join(sparse_labels_dir, image_name), sparse_image)

 
# Generate the dataset with create_data's default settings
# (signal_points=30, noise_points=100, time_dimension=1000, plot=False).
# `path` is the output root; `input_path` is the source dataset root.
create_data(
    path=r"N:\\Yr 3 Project Datasets\\PERF VALIDATION SETS\\a10K 100N 30S\\",
    input_path=r"N:\\Yr 3 Project Datasets\\RDT 10K MOVE\\",
)

Creating Data:   0%|          | 0/10000 [00:00<?, ?image/s]

Creating Data: 100%|██████████| 10000/10000 [00:54<00:00, 182.81image/s]
