To execute a cell with code, select it and press Shift+Enter.

# Calculate flatfield shading correction from all measured datasets, using BaSiC

More info about BaSiC:  
[https://github.com/peng-lab/BaSiCPy](https://github.com/peng-lab/BaSiCPy)  
[doi: 10.1038/ncomms14836](http://www.nature.com/articles/ncomms14836)

## Define file-paths:

Your files should be saved in the following structure.  
Name the folders of the individual rounds 'round*', where '*' can be any characters (e.g. 'round1', 'round2'):
```
base_dir
├───round1
│   ├───*A1*
│   │   └───*_Plate_*
│   │       └───TimePoint_1
│   │           ├───ZStep_1
│   │           ├───ZStep_2
│   │           .
│   │           .
│   ├───*A2*
│   │   └───*_Plate_*
│   │       └───TimePoint_1
│   │           ├───ZStep_1
│   │           ├───ZStep_2
│   │           .
│   │           .
│   .
│   .
├───round2
│   ├───*A1*
│   .
│   .
```
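The notebook finds the folders directly inside each 'round*' folder with a `glob` pattern of the form `round*` plus one more folder level. A minimal sketch of that matching, assuming the layout above: it builds a throwaway directory tree in a temporary folder (all folder names are made up for illustration) and lists what the pattern finds. `find_plate_dirs` is a hypothetical helper; unlike the notebook cell, it also skips loose files.

```python
import glob
import os
import tempfile


def find_plate_dirs(base_dir):
    """Return the sub-folders of every 'round*' folder in base_dir, sorted."""
    pattern = os.path.join(base_dir, 'round*', '*')
    # keep directories only; loose files next to the plate folders are skipped
    return sorted(p for p in glob.glob(pattern) if os.path.isdir(p))


with tempfile.TemporaryDirectory() as base_dir:
    # hypothetical layout: two rounds, plus one folder that does not match 'round*'
    for rel in ['round1/exp_A1/exp_Plate_1/TimePoint_1/ZStep_1',
                'round1/exp_A2/exp_Plate_2/TimePoint_1/ZStep_1',
                'round2/exp_A1/exp_Plate_3/TimePoint_1/ZStep_1',
                'other/exp_A1']:
        os.makedirs(os.path.join(base_dir, *rel.split('/')))
    plate_dirs = find_plate_dirs(base_dir)
    # express results relative to base_dir with '/' so they are easy to read
    found = [os.path.relpath(p, base_dir).replace(os.sep, '/') for p in plate_dirs]

print(found)  # ['round1/exp_A1', 'round1/exp_A2', 'round2/exp_A1']
```

The 'other' folder is ignored because it does not start with 'round', which is why the round folders must follow that naming.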

In the cell below, change the following file paths:
* `base_dir`: the directory that contains all 'round*' folders
* `save_dir`: the directory where the shading corrections will be saved (it is created if it does not exist yet)

Note: always separate folders with double backslashes (`\\`), because a single backslash starts an escape sequence in a Python string.
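To illustrate, here are equivalent ways of writing the example `base_dir` in Python. This sketch uses only the standard library: `ntpath` is the Windows flavour of `os.path`, so it joins with backslashes on any operating system, and the `r'...'` raw-string form avoids doubling every backslash.

```python
import ntpath

# three equivalent spellings of the same Windows path
p_escaped = 'Z:\\zmbstaff\\9780\\Raw_Data\\MD_1'  # doubled backslashes
p_raw = r'Z:\zmbstaff\9780\Raw_Data\MD_1'        # raw string: backslashes kept literally
p_joined = ntpath.join('Z:\\', 'zmbstaff', '9780', 'Raw_Data', 'MD_1')

print(p_escaped == p_raw == p_joined)  # True

# a single backslash can silently change the path:
# '\r' below is a carriage return, not a backslash followed by 'r'
p_broken = 'Z:\raw'
print(len(p_broken))  # 5, not 6
```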

In [None]:
base_dir = 'Z:\\zmbstaff\\9780\\Raw_Data\\MD_1'
save_dir = 'Z:\\zmbstaff\\9780\\Processed_Data\\MD_1\\shading_corrections_BaSiC'

The rest of the cells can be executed 'as is':

## Import modules:

In [None]:
import os
import glob
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tifffile

from basicpy import BaSiC
import jax
jax.config.update('jax_platform_name', 'cpu')

from zmb_hcs.parser_MD import parse_files_zmb, get_well_image_FCZYX
from zmb_hcs.BaSiC_helper_funs import get_middle_slice

## Find & structure files:

In [None]:
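# load all files into a single pandas dataframe (one row per image file)
plate_dirs = glob.glob(os.path.join(base_dir, 'round*', '*'))
files_list = []
for plate_dir in plate_dirs:
    files, _ = parse_files_zmb(plate_dir, query="")
    files['plate_dir'] = plate_dir  # remember which folder each file came from
    files_list.append(files)
files = pd.concat(files_list)
files = files[files.z.notnull()]  # keep only files with a z-position (z-stack slices)
files = files.reset_index()  # concat repeats index labels; make them unique again
channels = files['channel'].unique()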
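The `reset_index()` call matters because every per-folder table starts its index at 0, and the next cell writes channel names with `files.loc[indxs, ...]`, which relies on unique index labels. A small pandas sketch with two made-up per-plate tables (not real `parse_files_zmb` output) shows the problem and the fix:

```python
import pandas as pd

# two hypothetical per-plate tables, each with its own 0-based index
plate_a = pd.DataFrame({'path': ['a0.tif', 'a1.tif'], 'channel': [1, 2]})
plate_b = pd.DataFrame({'path': ['b0.tif', 'b1.tif'], 'channel': [1, 2]})

combined = pd.concat([plate_a, plate_b])
# index labels are now [0, 1, 0, 1]
print(combined.index.is_unique)          # False
# so a lookup by label returns rows from both plates
print(combined.loc[0, 'path'].tolist())  # ['a0.tif', 'b0.tif']

combined = combined.reset_index()  # old labels move into an 'index' column
print(combined.index.is_unique)    # True
print(combined.loc[0, 'path'])     # a0.tif
```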

In [None]:
# get the channel name (illumination setting) from the MetaSeries metadata
# of one file per plate_dir and channel, and assign it to all files of that group
for plate_dir in plate_dirs:
    files_plate_dir = files.query("plate_dir==@plate_dir")
    for channel in files_plate_dir['channel'].unique():
        files_channel = files_plate_dir.query("channel==@channel")
        fn_sel = files_channel['path'].iloc[0]  # first file of this channel
        with tifffile.TiffFile(fn_sel) as tif:
            metadata = tif.metaseries_metadata
            channel_name = metadata['PlaneInfo']['_IllumSetting_']
        indxs = files_channel.index
        files.loc[indxs, 'channel_name'] = channel_name
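The loop above reads exactly one file per plate folder and channel and copies its `_IllumSetting_` value to every row of that group. The same logic can be expressed with `groupby`. This is a sketch only: `annotate_channel_names` and the reader callback are hypothetical names, and the fake reader stands in for the `tifffile` metadata lookup so the example runs without image files.

```python
import pandas as pd


def annotate_channel_names(files, read_channel_name):
    """Add a 'channel_name' column, reading one file per (plate_dir, channel).

    read_channel_name(path) -> str is a stand-in for the MetaSeries lookup;
    it is called once per group, not once per file.
    """
    first_paths = files.groupby(['plate_dir', 'channel'])['path'].first()
    names = {key: read_channel_name(path) for key, path in first_paths.items()}
    keys = zip(files['plate_dir'], files['channel'])
    return files.assign(channel_name=[names[k] for k in keys])


# usage with made-up data and a fake reader that records its calls
files = pd.DataFrame({
    'plate_dir': ['r1', 'r1', 'r1', 'r2'],
    'channel': [1, 1, 2, 1],
    'path': ['r1_w1_a.tif', 'r1_w1_b.tif', 'r1_w2_a.tif', 'r2_w1_a.tif'],
})
fake_names = {'r1_w1_a.tif': 'DAPI', 'r1_w2_a.tif': 'FITC', 'r2_w1_a.tif': 'DAPI'}
calls = []


def fake_reader(path):
    calls.append(path)
    return fake_names[path]


files = annotate_channel_names(files, fake_reader)
print(files['channel_name'].tolist())  # ['DAPI', 'DAPI', 'FITC', 'DAPI']
print(len(calls))                      # 3 files read for 4 rows
```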

## Find used channels:

In [None]:
channel_names = files.channel_name.unique()
channel_names

## Calculate flatfield with BaSiC:

For each channel:
* Find all plates in 'base_dir' that contain this channel
* From every tile of these plates, load the middle z-slice; BaSiC then calculates the flatfield image from the stack of all these slices
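BaSiC (see the paper linked above) models each measured image as the true image multiplied by a flatfield S, plus an optional darkfield D. This notebook sets `get_darkfield=False`, so only S is estimated. To make concrete what a flatfield is and why many tiles are used, here is a naive averaging estimate on synthetic data. It is explicitly *not* BaSiC, which uses a more robust low-rank plus sparse decomposition; `naive_flatfield` is a hypothetical helper and the data are made up.

```python
import numpy as np


def naive_flatfield(stack):
    """Estimate a flatfield as the per-pixel mean over tiles, scaled to mean 1.

    stack: array of shape (n_tiles, Y, X). Simple averaging baseline, not BaSiC:
    it assumes the sample content averages out across tiles and no darkfield.
    """
    stack = np.asarray(stack, dtype=np.float64)
    mean_img = stack.mean(axis=0)
    return mean_img / mean_img.mean()


# synthetic tiles: random "sample" content times a known vignetting profile
rng = np.random.default_rng(0)
yy, xx = np.mgrid[-1:1:64j, -1:1:64j]
true_flat = 1.0 - 0.3 * (xx**2 + yy**2)  # darker towards the corners
true_flat /= true_flat.mean()
content = rng.uniform(50, 150, size=(200, 64, 64))
stack = content * true_flat              # measured = true * S (no darkfield)

est = naive_flatfield(stack)
print(f"mean abs error: {np.abs(est - true_flat).mean():.3f}")
```

Averaging only works when the content varies randomly between tiles; BaSiC is designed to cope with real samples where that assumption is weaker.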

In [None]:
basic_dict = {}
for channel_name in channel_names:
    print(f"Processing '{channel_name}' channel")
    print("loading data...")
    start_time = time.time()
    data_da, dx, dy = get_middle_slice(channel_name, files)
    data = data_da.compute()
    print(f"Took {np.round(time.time() - start_time)}s")
    print("running BaSiC...")
    start_time = time.time()
    basic = BaSiC(get_darkfield=False, smoothness_flatfield=1)
    basic.fit(data)
    basic_dict[channel_name] = basic
    print(f"Took {np.round(time.time() - start_time)}s\n")
print('Finished')

## Plot calculated flatfield images:

In [None]:
# squeeze=False keeps 'axes' 2-D, so indexing also works with a single channel
fig, axes = plt.subplots(1, len(channel_names), figsize=(15, 3), squeeze=False)
for n, channel_name in enumerate(channel_names):
    ax = axes[0, n]
    ax.set_title(channel_name)
    im = ax.imshow(basic_dict[channel_name].flatfield)
    fig.colorbar(im, ax=ax)
fig.tight_layout()

## Save flatfield images in 'save_dir':

In [None]:
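# save one flatfield image per channel as OME-TIFF, storing the pixel size in the
# metadata (dx, dy come from the last channel processed above, so this assumes
# all channels share the same pixel size)
os.makedirs(save_dir, exist_ok=True)
for channel_name in channel_names:
    fn = os.path.join(save_dir, f'flatfield_{channel_name}.ome.tif')
    with tifffile.TiffWriter(fn, bigtiff=True) as tif:
        tif.write(basic_dict[channel_name].flatfield,
                  photometric='minisblack',
                  metadata={'axes': 'YX',
                            'PhysicalSizeX': dx,
                            'PhysicalSizeY': dy})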
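The saved files can then be used to correct raw images. A minimal sketch, assuming the BaSiC model from the paper linked above (measured = true * S + D): the corrected image is (measured - D) / S, which reduces to a division by the flatfield because no darkfield was estimated here. `apply_flatfield` is a hypothetical helper; loading a saved file with `tifffile.imread` is shown only as a comment.

```python
import numpy as np


def apply_flatfield(image, flatfield, darkfield=None):
    """Return (image - darkfield) / flatfield as float64.

    With get_darkfield=False, as in this notebook, pass no darkfield,
    so the correction is a plain division by the flatfield.
    """
    image = np.asarray(image, dtype=np.float64)
    flatfield = np.asarray(flatfield, dtype=np.float64)
    if darkfield is not None:
        image = image - darkfield
    return image / flatfield


# in the notebook, the flatfield would come from a saved file, e.g.:
# flat = tifffile.imread(os.path.join(save_dir, f'flatfield_{channel_name}.ome.tif'))

# synthetic check: a uniform scene seen through a made-up vignetting flatfield
flat = np.array([[0.8, 1.0], [1.0, 1.2]])
raw = 100.0 * flat
corrected = apply_flatfield(raw, flat)
print(corrected)  # approximately 100 everywhere
```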