
## Spectral filtering of multichannel images

### <font color='red'> After clicking on a code cell, press "Shift+Enter" to run the code, or click on the "Run" button in the toolbar above.<br>

### To process your own data, replace the demo-data path (`get_path_to_demo_folder()`) and the file names below with the appropriate paths to your data.
</font>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tifffile
from pathlib import Path
from tapenade import get_path_to_demo_folder

## Loading spectral patterns

In [None]:
# Normalize to a Path so later cells can build file paths with the "/" operator.
path_to_data = Path(get_path_to_demo_folder())

channels = 4  # fill in here the number of acquisition channels of the image
species = np.array([1, 2, 3, 4])  # 1-based indices of the species to unmix
# species = np.array([1, 2])  # e.g. for GFP and Alexa488 only

# Load the reference spectrum file of every species. Each file is expected to hold a
# 2-D array of shape (n_slices, channels) — presumably one spectrum per z-slice,
# TODO confirm against how the *_medfilt.npy files were produced.
spectra = [np.load(path_to_data / f'species{i}_medfilt.npy') for i in species]

if spectra[0].ndim != 2 or spectra[0].shape[1] != channels:
    raise ValueError(
        f'species{species[0]}_medfilt.npy has shape {spectra[0].shape}; '
        f'expected (n_slices, {channels}).'
    )

# The first dimension is read from the data instead of being hard-coded, so
# stacks with a different number of slices work too.
# Resulting shape: (n_slices, channels, n_species).
spectralpatterns = np.zeros(spectra[0].shape + (species.shape[0],))
for i, spectrum in zip(species, spectra):
    # Species numbers are 1-based; species i is stored in column i - 1.
    spectralpatterns[:, :, i - 1] = spectrum

## Filtering

In [None]:
img = tifffile.imread(Path(path_to_data) / '03_Hoechst_Ecad_Bra_Sox2.tif')
# NOTE(review): the indexing below assumes img is laid out as (z, channel, y, x)
# and that img.shape[1] == channels — confirm for your own data.

# Positions of the selected species (species numbers are 1-based).
species_idx = species - 1
# Largest value representable in the uint16 output image.
uint16_max = np.iinfo(np.uint16).max

# Output image with the same shape as the input. Unmixed values are computed in
# floating point slice by slice and converted to uint16 only after clipping.
# The previous int16 buffer overflowed for any unmixed value above 32767.
image_filtered = np.zeros(img.shape, dtype=np.uint16)

for z in range(img.shape[0]):
    slice_z = img[z]  # (channel, y, x)
    # Mean intensity of each channel in this z-slice.
    Iavg_channels = slice_z.mean(axis=(1, 2))
    if np.any(Iavg_channels <= 0):
        # A channel without signal would make the 1/mean weights infinite and
        # fill the slice with inf/NaN; leave this slice empty instead.
        print(f'Skipping z={z}: at least one channel has zero mean intensity.')
        continue
    # Diagonal normalization matrix: each channel is weighted by 1 / its mean.
    D = np.diag(1 / Iavg_channels)
    # Spectral patterns of the selected species for this slice: (species, channel).
    specpatterns_z = spectralpatterns[z][:, species_idx].T
    # Weighted least-squares unmixing weights w = (S D S^T)^-1 S D, computed with
    # solve() instead of an explicit matrix inverse for numerical stability.
    w = np.linalg.solve(specpatterns_z @ D @ specpatterns_z.T, specpatterns_z @ D)
    # Each species image is a weighted sum of the channels: (species, y, x).
    unmixed = np.einsum('sc,cyx->syx', w, slice_z)
    # Negative values become 0 and bright values saturate at the uint16 maximum.
    # The cast then truncates the fractional part.
    image_filtered[z, species_idx] = np.clip(unmixed, 0, uint16_max).astype(np.uint16)



## Plot results

In [None]:
z_to_plot = 50  # index of the z-slice to plot

# squeeze=False always returns a 2-D grid of axes, even when channels == 1.
# With the default, a single panel is returned as a bare Axes object that
# cannot be indexed. Taking row 0 gives a 1-D array of axes in every case.
fig1, ax1 = plt.subplots(1, channels, figsize=(20, 5), squeeze=False)
ax1 = ax1[0]
for i in range(channels):
    ax1[i].imshow(img[z_to_plot, i, :, :])
fig1.suptitle('Channels before spectral filtering')

fig2, ax2 = plt.subplots(1, channels, figsize=(20, 5), squeeze=False)
ax2 = ax2[0]
for i in range(channels):
    ax2[i].imshow(image_filtered[z_to_plot, i, :, :])
fig2.suptitle('Channels after spectral filtering')

fig1.tight_layout()
fig2.tight_layout()