In [None]:
### This example notebook integrates DIMR / DeepSNiF denoising from the IMC_Denoise package into the steinbock (https://bodenmillergroup.github.io/steinbock/) workflow
### Where it fits in the steinbock pipeline: run it after the .mcd files have been converted into .tiff files (the img folder)

### Edited by Ben Caiello from the example DeepSNiF train/run jupyter notebook script in IMC_denoise
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import tifffile as tp
from IMC_Denoise.IMC_Denoise_main.DIMR import DIMR
from IMC_Denoise.IMC_Denoise_main.DeepSNiF import DeepSNiF
from IMC_Denoise.DeepSNiF_utils.DeepSNiF_DataGenerator import DeepSNiF_DataGenerator

# Clean-up carried over from the original IMC_denoise example notebook, where generated_patches was a
# notebook-level variable. Here it only exists inside DeepSNiF_train, so this is just a safeguard that
# frees memory if the kernel still holds patches from a previous run of that example.
if 'generated_patches' in globals():
    del generated_patches

train_directory = 'C:/Users/....../training_img' # change this to the directory of your training images
Raw_directory = 'C:/Users/....../img' # change this to the directory of pre-denoised tiffs you want to process (always ends with /img for steinbock exports).
output_directory = 'C:/Users/....../output' # change this to the directory the denoised images should be written to

# Make all three directories the same to train DeepSNiF on the whole dataset and then overwrite your current files with the denoised versions.

#### Choose which channels to denoise:
# Replace the placeholder below with a list of 0-based channel indices, e.g. [3, 7, 12].
# An empty list ([]) makes the denoising cell process every channel, which can take a very long time.
channel_names = [ChannelNumber_1, ..., ChannelNumber_N]     # list of integers corresponding to the channels you want to denoise

# The most convenient way to integrate steinbock and IMC_denoise is to have the output directly overwrite the img folder.
# In that case, first make a one-off "<img>_pre_denoise" backup copy of the img folder, so that undoing/redoing the
# denoising does not require steinbock to convert the .mcd files to .tiffs again.
# The paths are compared after normalization, so 'C:/x/img' and 'C:/x/img/' count as the same folder.
# The backup path is built from the normalized path because, with a trailing slash,
# Raw_directory + "_pre_denoise" would place the backup *inside* img, next to the images being processed.
raw_dir_normalized = os.path.normpath(os.path.abspath(Raw_directory))
output_dir_normalized = os.path.normpath(os.path.abspath(output_directory))
if os.path.normcase(raw_dir_normalized) == os.path.normcase(output_dir_normalized):
    backup_directory = raw_dir_normalized + "_pre_denoise"
    if not os.path.isdir(backup_directory):
        shutil.copytree(Raw_directory, backup_directory)
    else:
        print('img_pre_denoise folder already exists, will not copy current img directory')
        
        


In [None]:
## Step 0.1: define a helper that wraps the IMC_denoise training steps (patch generation + DeepSNiF training)
# This keeps the channel-iteration code in the next cell simple.
# Note the run_type = 'multi_channel_tiff' argument in the DataGenerator call: it is what allows multi-channel .tiffs to be read.
# Without specifying run_type, IMC_Denoise should default to reading its official single-channel .tiff format.
def DeepSNiF_train(channel_name, n_neighbours = 4, n_iter = 3, window_size = 3, train_epoches = 25, train_initial_lr = 1e-3,
                  train_batch_size = 128, pixel_mask_percent = 0.2, val_set_percent = 0.15, loss_function = "I_divergence",
                  weights_name = None, loss_name = None, weights_save_directory = None, lambda_HF = 3e-6,
                  load_directory = None, run_type = 'multi_channel_tiff', ratio_thresh = 0.8,
                  patch_row_size = 64, patch_col_size = 64, row_step = 60, col_step = 60):
    '''
    Generate training patches for one channel and train a DeepSNiF model on them.

    This merges the IMC_denoise patch-generation and training steps so that the channel loop stays
    simple, using mostly default hyperparameters. If you need finer control over many hyperparameters,
    it may be better to split this function back into its original pieces.

    Parameters
    ----------
    channel_name : int
        Index of the channel to train on (passed to DeepSNiF_DataGenerator as channel_name).
    n_neighbours, n_iter, window_size :
        DIMR settings used during patch generation. They are returned unchanged so the caller can
        denoise with exactly the same values.
    train_epoches, train_initial_lr, train_batch_size, pixel_mask_percent, val_set_percent,
    loss_function, weights_name, loss_name, weights_save_directory, lambda_HF :
        Forwarded to DeepSNiF as train_epoches, train_learning_rate, train_batch_size, mask_perc_pix,
        val_perc, loss_func, weights_name, loss_name, weights_dir and lambda_HF respectively.
    load_directory : str or None
        Directory containing the training images. None (the default) uses the notebook-level
        ``train_directory`` variable, which was the only option before this parameter existed.
    run_type : str
        Image format passed to DeepSNiF_DataGenerator. The default 'multi_channel_tiff' is what allows
        multi-channel .tiffs (such as steinbock img exports) to be read.
    ratio_thresh, patch_row_size, patch_col_size, row_step, col_step :
        Patch-generation settings forwarded unchanged to DeepSNiF_DataGenerator.

    Returns
    -------
    tuple
        (n_neighbours, n_iter, window_size, train_loss, val_loss, deepsnif), where train_loss and
        val_loss are the values returned by DeepSNiF.train and deepsnif is the trained model instance.
    '''
    if load_directory is None:
        # Fall back to the global set in the configuration cell (the original behaviour).
        load_directory = train_directory

    data_generator = DeepSNiF_DataGenerator(run_type = run_type, channel_name = channel_name, ratio_thresh = ratio_thresh,
                                            patch_row_size = patch_row_size, patch_col_size = patch_col_size,
                                            row_step = row_step, col_step = col_step,
                                            n_neighbours = n_neighbours, n_iter = n_iter, window_size = window_size)
    generated_patches = data_generator.generate_patches_from_directory(load_directory = load_directory)
    print('The shape of the generated training set for channel ' + str(channel_name) + ' is ' + str(generated_patches.shape) + '.')

    deepsnif = DeepSNiF(train_epoches = train_epoches,
                        train_learning_rate = train_initial_lr,
                        train_batch_size = train_batch_size,
                        mask_perc_pix = pixel_mask_percent,
                        val_perc = val_set_percent,
                        loss_func = loss_function,
                        weights_name = weights_name,
                        loss_name = loss_name,
                        weights_dir = weights_save_directory,
                        # Use the freshly trained model directly; do not read previously saved weights.
                        is_load_weights = False,
                        lambda_HF = lambda_HF)
    train_loss, val_loss = deepsnif.train(generated_patches)
    return n_neighbours, n_iter, window_size, train_loss, val_loss, deepsnif


In [None]:
# Step 1: iterate through the channels, training DeepSNiF on each one and then running DIMR + DeepSNiF on every image.

# Only regular .tif/.tiff files are processed, so stray entries in the image folder (e.g. desktop.ini,
# .DS_Store, sub-folders) do not crash the loop. Sorting makes the processing order reproducible.
TIFF_EXTENSIONS = ('.tif', '.tiff')
image_files = sorted(
    name for name in os.listdir(Raw_directory)
    if name.lower().endswith(TIFF_EXTENSIONS) and os.path.isfile(os.path.join(Raw_directory, name))
)
if not image_files:
    # Fail before the (expensive) training instead of silently doing nothing.
    raise FileNotFoundError('No .tif/.tiff files found in ' + Raw_directory)

# If no channel names are specified, all channels are run! Depending on how many channels there are and
# how long the training takes, that can take a very long time. A local copy is used so that the
# configuration variable channel_names is not modified by this cell.
if channel_names:
    channels_to_denoise = list(channel_names)
else:
    print('No channels specified! Will denoise all channels in provided .tiffs')
    with tp.TiffFile(os.path.join(Raw_directory, image_files[0])) as tif:
        # The channel count is taken as the number of pages in the first image.
        channels_to_denoise = list(range(len(tif.pages)))

# tifffile does not create missing folders, so make sure the output folder exists before writing.
os.makedirs(output_directory, exist_ok=True)

# For the first channel, read from the raw images. Afterwards, read the partially denoised copies in
# output_directory so that channels denoised earlier are not overwritten with raw data when channel N+1
# is written. A local variable is used instead of reassigning Raw_directory, so re-running this cell
# starts again from the raw images rather than from already-denoised ones.
source_directory = Raw_directory
for channel_index in channels_to_denoise:
    n_neighbours, n_iter, window_size, train_loss, val_loss, deepsnif = DeepSNiF_train(channel_index)
    for img in image_files:
        img_path = os.path.join(source_directory, img)
        with tp.TiffFile(img_path) as tif:
            Img_raw = tif.pages[channel_index].asarray()
        Img_DIMR_DeepSNiF = deepsnif.perform_IMC_Denoise(Img_raw, n_neighbours = n_neighbours, n_iter = n_iter, window_size = window_size)
        numpy_tiff = tp.imread(img_path)
        numpy_tiff[channel_index] = Img_DIMR_DeepSNiF
        tp.imwrite(os.path.join(output_directory, img), numpy_tiff, photometric='minisblack')
    source_directory = output_directory

# Steinbock should now be able to work seamlessly with the denoised files written to output_directory