# Data Preprocessing: Transforming MPP Data into Images for Predicting the Transcription Rate (TS)

This notebook is meant to convert raw cell data from several wells into multichannel images (along with their corresponding metadata).

Data was taken from:
`/storage/groups/ml01/datasets/raw/20201020_Pelkmans_NascentRNA_hannah.spitzer/` and server `vicb-submit-01`. 

In the preprocessing done in this notebook, NO discrimination of channels is done! All the channels are saved in the same order, and all of them are also projected into scalars, in case another channel is wanted as the target variable. The objective of this preprocessing is to create an 'imaged' version of the MPP data.

The selection of the input channels (input_channels) and of the target variable is done during the conversion into a TensorFlow dataset!

Considerations:
- The MPPData instances are saved in a dictionary and are never merged (only their metadata is)! This is because there is no 'in place' option to extend numpy arrays (they need contiguous space in memory). Therefore, any merging of numpy arrays results in object duplication in memory during the merging process.
- The images are saved as variables in each MPPData instance as arrays of dtype=np.uint16. This saves a lot of RAM during the processing. However, this only allows values between 0 and 65535 (which is the measurement range of MPP data). Therefore, the normalization is done image by image while saving to disk, again to reduce RAM usage.

Load libraries:

In [None]:
import numpy as np
import pandas as pd
# To display all the columns
pd.options.display.max_columns = None
import os
import sys
import matplotlib.pyplot as plt
import json
import math
import matplotlib.pyplot as plt

Load Parameters:

In [None]:
# Do not touch the value of PARAMETERS_FILE!
# When this notebook is executed with jupyter-nbconvert (from script), 
# it will be replaced automatically
PARAMETERS_FILE = 'dont_touch_me-input_parameters_file'

# Fail early with an explicit message when the parameters file is missing;
# otherwise read the JSON parameters into the dict `p`.
if os.path.exists(PARAMETERS_FILE):
    with open(PARAMETERS_FILE) as params_file:
        p = json.load(params_file)
else:
    raise Exception('Parameter file {} does not exist!'.format(PARAMETERS_FILE))

# Show the available parameter names as the cell output
p.keys()

Take a look into the loaded parameters:

In [None]:
# Echo the whole parameters dict (loaded from PARAMETERS_FILE) as the cell output
p

Set paths and Load external libraries:

In [None]:
def _check_dir_exists(path, error_template, found_template):
    """Raise ``Exception(error_template)`` if *path* is missing, else print ``found_template``."""
    if not os.path.exists(path):
        raise Exception(error_template.format(path))
    print(found_template.format(path))


# Load data path
DATA_DIR = p['raw_data_dir']
_check_dir_exists(DATA_DIR, 'Data path {} does not exist!', 'DATA_DIR: {}')

# Load external libraries path
EXTERNAL_LIBS_PATH = p['external_libs_path']
_check_dir_exists(EXTERNAL_LIBS_PATH,
                  'External library path {} does not exist!',
                  'EXTERNAL_LIBS_PATH: {}')
# Make EXTERNAL_LIBS_PATH importable (inserted right after sys.path[0])
sys.path.insert(1, EXTERNAL_LIBS_PATH)

# Load external libraries
from pelkmans.mpp_data import (
    MPPData,
    save_to_file_targets_masks_and_normalized_images as normalize_and_save,
    get_image_normalization_vals as get_normalization_vals,
    get_concatenated_metadata,
)

# Set logging configuration: the log file is overwritten on every run and
# the level name (e.g. 'INFO') comes from the parameters file.
import logging
log_file = p['log_file']
log_level = getattr(logging, p['log_level'])
logging.basicConfig(filename=log_file, filemode='w', level=log_level)
logging.info('Parameters loaded from file:\n{}'.format(PARAMETERS_FILE))

Check available data (Perturbations and Wells):

In [None]:
logging.info('Reading local available perturbations-wells...')
# Every sub-directory of DATA_DIR is a perturbation; every sub-directory of
# a perturbation is a well. Build {perturbation: [well, ...]}.
perturbations = [
    entry for entry in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, entry))
]
local_data = {
    per: [
        well for well in os.listdir(os.path.join(DATA_DIR, per))
        if os.path.isdir(os.path.join(DATA_DIR, per, well))
    ]
    for per in perturbations
}

Select the perturbations and their wells to process:

In [None]:
msg = 'Local available perturbations-wells:\n{}'.format(local_data)
print(msg)
logging.debug(msg)

# In case you only want to load some specific perturbations and/or wells here:
#selected_data = {
#    '184A1_hannah_unperturbed': ['I11', 'I09'],
#    '184A1_hannah_TSA': ['J20', 'I16'],
#}

# Load perturbations-wells from parameters file
selected_data = p['perturbations_and_wells']
# Total number of wells that will be processed
n_wells = sum(len(selected_wells) for selected_wells in selected_data.values())

print('\nSelected perturbations-wells:\n{}'.format(selected_data))

# Build the directory path of every selected well, stopping at the first one
# that does not exist under DATA_DIR
data_dirs = []
for per, per_wells in selected_data.items():
    for w in per_wells:
        d = os.path.join(DATA_DIR, per, w)
        data_dirs.append(d)
        if os.path.exists(d):
            continue
        msg = '{} does not exist!\nCheck if selected_data contain elements only from local_data dict.'.format(d)
        logging.error(msg)
        raise Exception(msg)
p['data_dirs'] = data_dirs

Process data:

In [None]:
msg = 'Starting processing of {} wells...'.format(n_wells)
logging.info(msg)

# One list per split; each list receives one MPPData instance per well.
# The instances are kept separate (never merged) to avoid duplicating
# large numpy arrays in memory.
data = {
    'train': [],
    'val': [],
    'test': []
}

for w, data_dir in enumerate(p['data_dirs'], 1):
    msg = 'Processing well {}/{} from dir {}...'.format(w, n_wells, data_dir)
    logging.info(msg)
    print('\n\n' + msg)
    # Load data as an MPPData object
    mpp_temp = MPPData.from_data_dir(data_dir,
                                     dir_type=p['dir_type'],
                                     seed=p['seed'])

    # Add cell cycle to metadata (G1, S, G2)
    # Important! If mapobject_id_cell is not in cell_cycle_file =>
    # its corresponding cell is in Mitosis phase!
    if p['add_cell_cycle_to_metadata']:
        msg = 'Adding cell cycle to metadata...'
        logging.info(msg)
        print(msg)
        mpp_temp.add_cell_cycle_to_metadata(os.path.join(DATA_DIR, p['cell_cycle_file']))

    # Add well info to metadata
    if p['add_well_info_to_metadata']:
        msg = 'Adding well info to metadata...'
        logging.info(msg)
        print(msg)
        mpp_temp.add_well_info_to_metadata(os.path.join(DATA_DIR, p['well_info_file']))

    # Remove unwanted cells
    if p.get('filter_criteria', None) is not None:
        msg = 'Removing unwanted cells...'
        logging.info(msg)
        print(msg)
        mpp_temp.filter_cells(p['filter_criteria'], p['filter_values'])

    # Subtract background values for each channel
    if p['subtract_background']:
        msg = 'Subtracting background...'
        logging.info(msg)
        print(msg)
        mpp_temp.subtract_background(os.path.join(DATA_DIR, p['background_value']))

    # Project every uni-channel image into a scalar for further analysis
    if p['project_into_scalar']:
        msg = 'Projecting data...'
        logging.info(msg)
        print(msg)
        mpp_temp.add_scalar_projection(p['method'])

    # Split data into train, validation and test
    msg = 'Spliting data into train, validation and test'
    logging.info(msg)
    print(msg)
    train_temp, val_temp, test_temp = mpp_temp.train_val_test_split(p['train_frac'], p['val_frac'])
    # The well-level object is no longer needed once it is split
    del mpp_temp

    if p['convert_into_image']:
        msg = 'Converting data into images...'
        logging.info(msg)
        print(msg)
        for split_mpp, split_label in ((train_temp, 'Train'),
                                       (val_temp, 'Validation'),
                                       (test_temp, 'Test')):
            split_mpp.add_image_and_mask(data='MPP',
                                         remove_original_data=p['remove_original_data'],
                                         img_size=p['img_size'])
            msg = '{} dataset converted'.format(split_label)
            logging.info(msg)
            print(msg)
        # Drop the loop alias so only the split names below refer to the data
        del split_mpp

    # Validate same channels across wells. The names are compared as plain
    # lists: comparing two pandas Series with `==` raises a ValueError when
    # their lengths/index labels differ, which would hide the real problem.
    if len(data['val']) > 0:
        if list(data['val'][0].channels.name) != list(val_temp.channels.name):
            msg = 'Channels across MPPData instances are not the same!'
            logging.error(msg)
            raise Exception(msg)

    data['val'].append(val_temp)
    del val_temp
    data['test'].append(test_temp)
    del test_temp
    data['train'].append(train_temp)
    del train_temp
    

In [None]:
# During this preprocessing we save all channels! Their ids are read from the
# first validation instance (the processing loop checks that the validation
# instances of all wells have the same channel names).
first_val_channels = data['val'][0].channels
channels_ids = first_val_channels.channel_id.values

Get normalization values from training set:

In [None]:
# Normalization values are obtained from the train data only
# (inner `percentile`% of the train data).
if p['normalise']:
    msg = 'Normalizing data...'
    logging.info(msg)
    if p['convert_into_image']:
        # Only get the normalization parameters here; the images are
        # normalized one by one while they are saved into files.
        rescale_values = get_normalization_vals(
            instance_dict=data['train'],
            input_channel_ids=channels_ids,
            percentile=p['percentile'])
        # Stored inside this branch so that the cell neither fails with a
        # NameError when convert_into_image is disabled, nor saves a stale
        # `rescale_values` left over from another cell.
        p['normalise_rescale_values'] = list(rescale_values)
    else:
        msg = ('normalise is enabled but convert_into_image is disabled: '
               'no image normalization values were computed.')
        logging.warning(msg)
        print(msg)
    if not p['remove_original_data']:
        # Get normalization values and normalize original MPPData
        # TODO: Rewrite rescale_intensities_per_channel to deal with
        # a dictionary of MPPData instances, instead of a big numpy array
        # (this is to avoid the duplication in memory)
        #rescale_values = train.rescale_intensities_per_channel(percentile=p['percentile'], )
        #_ = val.rescale_intensities_per_channel(rescale_values=rescale_values)
        #_ = test.rescale_intensities_per_channel(rescale_values=rescale_values)
        pass

As an extra, all channels are projected into a scalar and saved in the metadata. The idea is to have scalar data for simpler analyses. The projected data is also normalized using the train projection values.

In [None]:
# Merge the metadata of every well/split into a single table and, when both
# normalisation and scalar projection are enabled, normalize the projected
# values with rescale values computed on the 'train' split.
msg = 'Normalizing Projected data...'
logging.info(msg)

metadata, rescale_values = get_concatenated_metadata(
    mppdata_dict=data,
    # Logical `and` instead of bitwise `&`: the flags come from JSON, and
    # `&` only acts as a logical AND for real booleans (e.g. `1 & 2 == 0`).
    normalize=bool(p['normalise'] and p['project_into_scalar']),
    norm_key='train',
    projection_method=p['method'],
    percentile=p['percentile'])

# NOTE(review): assumes get_concatenated_metadata returns an iterable of
# rescale values even when normalize is False — confirm in pelkmans.mpp_data.
p['normalise_rescale_values_scalars'] = list(rescale_values)

In [None]:
# Display the concatenated metadata table as the cell output
metadata

## Save data

Prepare to save data:

In [None]:
import shutil

msg = 'Starting data saving process...'
logging.info(msg)

# Create a fresh output dir: any existing directory at outdir is deleted first.
outdir = p['output_data_dir']
if os.path.exists(outdir):
    msg = 'Warning! Directory {} already exist! Deleting...\n'.format(outdir)
    logging.warning(msg)
    print(msg)
    try:
        shutil.rmtree(outdir)
    except OSError as e:
        msg = 'Dir {} could not be deleted!\n\nOSError: {}'.format(outdir, e)
        logging.error(msg)
        print(msg)
        # Re-raise the real cause; otherwise os.makedirs below fails with a
        # misleading FileExistsError.
        raise

msg = 'Creating dir: {}'.format(outdir)
logging.info(msg)
print(msg)
os.makedirs(outdir, exist_ok=False)
    

Save images, masks and targets of all channels into separate files, named after the mapobject_id_cell of each cell:

In [None]:
# Normalize the images of every split with the train-derived values and
# write them (plus masks and targets) to outdir; only done when images exist.
if p['convert_into_image']:
    msg = 'Saving images and masks...'
    logging.info(msg)

    image_norm_vals = np.array(p['normalise_rescale_values'])
    output_files = normalize_and_save(
        mppdata_dict=data,
        norm_vals=image_norm_vals,
        channels_ids=channels_ids,
        projection_method=p['method'],
        outdir=outdir,
    )
    p['output_files'] = output_files
    print(output_files)

Save metadata and used parameters

In [None]:
msg = 'Saving Parameters and Metadata...'
logging.info(msg)

# save params (the context manager guarantees the JSON is flushed and the
# file handle is closed)
with open(os.path.join(outdir, 'params.json'), 'w') as params_out:
    json.dump(p, params_out, indent=4)

# save metadata
metadata.to_csv(os.path.join(outdir, 'metadata.csv'))

# Save used channels (taken from the first training instance)
data['train'][0].channels.to_csv(os.path.join(outdir, 'channels.csv'))

Finally, load one saved file and look at its content to check that everything was done correctly:

In [None]:
# Sanity check: load one randomly chosen saved cell and display its content.
# Files are only written when images were created, so skip otherwise.
if p['convert_into_image']:
    cell_id = np.random.choice(metadata['mapobject_id_cell'].values)
    # Split ('train'/'val'/'test') the cell was assigned to; it is also the
    # sub-directory of outdir used in the file path below
    cell_subset = metadata.loc[metadata['mapobject_id_cell'] == cell_id, 'set'].values[0]
    cell_file = os.path.join(outdir, cell_subset, str(cell_id) + '.npz')
    # np.load returns an NpzFile for .npz archives; close it once the arrays
    # have been read into memory
    with np.load(cell_file) as cell:
        cell_img = cell['img']
        cell_mask = cell['mask']
        cell_targets = cell['targets']

    print('Cell image shape: {}\n'.format(cell_img.shape))
    print('Cell mask shape: {}\n'.format(cell_mask.shape))
    print('Cell target shape: {}\n'.format(cell_targets.shape))
    print('Cell targets: {}\n'.format(cell_targets))

    # Now take a look into its image (first 3 channels shown as RGB;
    # note matplotlib ignores `cmap` for 3-channel data)
    plt.figure(figsize=(10, 10))
    plt.imshow(cell_img[:, :, 0:3],
               cmap=plt.cm.PiYG,
               vmin=0, vmax=1)
    plt.show()
else:
    print('convert_into_image is disabled: no cell files were saved to inspect.')

logging.info('\n\nPREPROCESSING FINISHED!!!!----------------------')