# Fine Tuning on DCC

This notebook shows how to fine-tune the network on new data.

**User inputs are capitalized throughout this notebook.**

In [1]:
# imports
import sys
sys.path.append("/home/gabriel/dev/automatic-transient-detection")
from pathlib import Path
import shutil
import collections
import numpy as np
from tqdm import tqdm
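try:
    from skimage.external import tifffile  # copy bundled with older scikit-image releases
except ImportError:
    import tifffile  # standalone package, for environments where skimage.external no longer exists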

# DataBase handler
import database

# Network handler
from ATD.common.experiment import Experiment

# Various utils functions
from ATD.utils import ioutils as iou
from ATD.utils import arrayutils as au
from ATD.utils import calciumutils as cu
from ATD.utils import listutils as lu
from ATD.utils import datasetutils as du

# Reload the database module automatically, since it is still being edited
%load_ext autoreload
%autoreload 1
%aimport database

## Database Creation
The following section creates the database from a [Google spreadsheet](https://docs.google.com/spreadsheets/d/12tEG3494V3y24qU8f11Vzui85sqt3VTwUzHBfEl0sQA/edit#gid=409641959).


In [2]:
CREATE_DATABASE = True
DATABASE_FILE = Path("../h5file/dcc_db.h5") # location of the database
SPREADSHEET = 'Raw Data List' # title of the spreadsheet document
WORKSHEET = 'DCC' # title of the worksheet in the spreadsheet
secretJSON = '../bot-reader-key.json' # file containing the authorizations to read the spreadsheet

In [3]:
# Get a worksheet object we will read with
worksheet = database.Database.oauth_and_get_worksheet(SPREADSHEET, WORKSHEET, secretJSON)

# Remove an existing database if it exists and CREATE_DATABASE is True
if CREATE_DATABASE and DATABASE_FILE.exists():
    DATABASE_FILE.unlink()
    
# Instantiate the database object and populate the database from the worksheet
db = database.Database(DATABASE_FILE)
if CREATE_DATABASE:
    db.load_streams_from_worksheet(worksheet)
db.close()

Creating a new database at ../h5file/dcc_db.h5...
Adding the stream
Imaging of neuron 0 on coverslip 0
Experiment SEP-DCC with stream condition ctrl
The stream Stream_GluA1_001.tif was image on 2018-06-18 00:00:00
Unique stream-id is: stream-0
A total of 1 streams was added...


## Inference of the network on the movie

In [4]:
# Paths to the two networks we want to compare
PRE_TRAINED_NETWORK_PATH = Path("/home/gabriel/results/atd/unet2d/unet2d-foldB-2fgr") # the pre-trained network
FINE_TUNED_NETWORK_PATH = Path("/home/gabriel/results/fine-tuned/sted/test/") # the new fine-tuned network

postprocess_params = {
    'minimal_time': 1,
    'minimal_height': 1,
    'minimal_width': 1,
    'threshold': 0.5
}
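The `postprocess_params` dictionary controls how the network output is turned into discrete events. The exact logic lives inside `db.segment_and_detect_msct`, which is not shown here, so the sketch below is only an assumption based on the parameter names and the printed log ("Regions removed", "Computing centroids of events"): threshold the `(time, height, width)` probability map (strict `>` assumed), label 3D connected components, drop regions whose extent along any axis is below the corresponding `minimal_*` value, and compute the centroid of each remaining event. The function name `postprocess` is hypothetical.

```python
import numpy as np
from scipy import ndimage


def postprocess(probabilities, threshold=0.5, minimal_time=1,
                minimal_height=1, minimal_width=1):
    """Hypothetical postprocessing of a (T, H, W) probability map.

    Returns the relabelled mask of kept events, their centroids (t, y, x)
    and the number of removed regions. Illustration only.
    """
    # Binarize the network output (strict comparison is an assumption)
    mask = probabilities > threshold
    # Label 3D connected components over (time, y, x), face connectivity
    labels, num_regions = ndimage.label(mask)
    kept = np.zeros_like(labels)
    centroids = []
    next_label = 1
    for index, region_slice in enumerate(ndimage.find_objects(labels), start=1):
        if region_slice is None:
            continue
        # Extent of the region's bounding box along each axis
        duration, height, width = (s.stop - s.start for s in region_slice)
        if duration < minimal_time or height < minimal_height or width < minimal_width:
            continue  # too small along at least one axis: the region is removed
        local = labels[region_slice] == index
        kept[region_slice][local] = next_label  # basic slicing gives a view, so this writes into kept
        offset = np.array([s.start for s in region_slice])
        centroids.append(tuple(np.array(ndimage.center_of_mass(local)) + offset))
        next_label += 1
    removed = num_regions - (next_label - 1)
    return kept, centroids, removed
```

With the values used above (all minimal sizes equal to 1), every thresholded region is kept, which matches the "0 removed" line in the inference log.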

In [5]:
# Define the network experiment wrapper
pretrained_xp = Experiment(str(PRE_TRAINED_NETWORK_PATH), save=False, verbose=False)
finetuned_xp = Experiment(str(FINE_TUNED_NETWORK_PATH), save=False, verbose=False)

# Load the database
db = database.Database(DATABASE_FILE)

# We will only infer on the first stream
stream = db[0]

db.segment_and_detect_msct(stream, finetuned_xp, postprocess_params)

segmentation = stream['segmentation'][...]
tifffile.imsave('/home/gabriel/results/fine-tuned/dcc/segmentation_finetuned_sted.tif', segmentation.astype(np.uint16))
db.close()


Segmentation and detection of 1 streams...
Currently doing stream :
Imaging of neuron 0 on coverslip 0
Experiment SEP-DCC with stream condition ctrl
The stream Stream_GluA1_001.tif was image on 2018-06-18 00:00:00
Unique stream-id is: stream-0
Stream shape: (1000, 512, 512)
Infering on the stream...


100%|██████████| 1000/1000 [03:40<00:00,  4.54it/s]
0 Regions removed so far: 100%|██████████| 539/539 [00:00<00:00, 1549.08it/s]


Computing centroids of events...
There is 539 detected events, 0 removed


## MiniFinder mSCTs Region Props Extraction
Since we use masks produced by the MiniFinder analysis software (the mask locations are listed in the Mini-Finder column of the worksheet), we need to extract the properties of the segmented regions from these masks.
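The property names requested in `MINI_REGIONPROPS_PARAMS` (`area`, `bbox`, `max_intensity`, `major_axis_length`, ...) match those of scikit-image's `regionprops`. As a rough, numpy-only illustration of what they mean, the sketch below computes most of them for a 2D label image; it is not the code used by `db.mini_regionprops_extraction`, the function name `region_props` is hypothetical, and `convex_area` is left out for brevity. The axis lengths use the second-moment definition (4 times the square root of the eigenvalues of the pixel-coordinate covariance).

```python
import numpy as np


def region_props(labels, intensity):
    """Hypothetical per-region properties for a 2D label image (0 = background)."""
    props = {}
    for label in np.unique(labels):
        if label == 0:
            continue  # skip background
        rows, cols = np.nonzero(labels == label)
        coords = np.stack([rows, cols], axis=1).astype(float)
        # Population covariance of the pixel coordinates of this region
        cov = np.cov(coords, rowvar=False, bias=True)
        # Eigenvalues sorted from largest to smallest
        eigvals = np.sort(np.linalg.eigvalsh(cov))[::-1]
        props[int(label)] = {
            'area': rows.size,
            # (min_row, min_col, max_row + 1, max_col + 1), as in scikit-image
            'bbox': (rows.min(), cols.min(), rows.max() + 1, cols.max() + 1),
            'max_intensity': intensity[rows, cols].max(),
            'major_axis_length': 4 * np.sqrt(max(eigvals[0], 0.0)),
            'minor_axis_length': 4 * np.sqrt(max(eigvals[1], 0.0)),
        }
    return props
```

For a filled 3 x 5 rectangle, for example, the major axis length is `4 * sqrt(2)` because the column coordinates 0..4 have a variance of 2.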

In [6]:
EXTRACT_MINI_REGION_PROPS = True
REDO = True # If REDO is True, properties that already exist are recalculated
MINI_REGIONPROPS_PARAMS = {
    'msct-props': [
        'area', 
        'convex_area',
        'bbox',
        'max_intensity',
        'major_axis_length',
        'minor_axis_length'
    ]
}

In [7]:
# For this specific example, we don't have the mask of the last stream,
# so we won't extract the properties of this stream
db = database.Database(DATABASE_FILE)
if EXTRACT_MINI_REGION_PROPS:
    streams = list(db)
    streams.pop(-1)
    db.mini_regionprops_extraction(streams, MINI_REGIONPROPS_PARAMS, redo=REDO)
db.close()


Extracting mini's properties of stream /streams/stream-0


34 Regions removed so far: 100%|██████████| 323/323 [00:04<00:00, 79.08it/s]


Computing centroids of events...
There is 289 detected events, 34 removed

Extracting mini's properties of stream /streams/stream-1


23 Regions removed so far: 100%|██████████| 135/135 [00:01<00:00, 95.37it/s]


Computing centroids of events...
There is 112 detected events, 23 removed

Extracting mini's properties of stream /streams/stream-2


93 Regions removed so far: 100%|██████████| 736/736 [00:12<00:00, 60.05it/s] 


Computing centroids of events...
There is 643 detected events, 93 removed

Extracting mini's properties of stream /streams/stream-3


24 Regions removed so far: 100%|██████████| 642/642 [00:03<00:00, 181.76it/s]


Computing centroids of events...
There is 618 detected events, 24 removed

Extracting mini's properties of stream /streams/stream-4


112 Regions removed so far: 100%|██████████| 708/708 [00:15<00:00, 46.05it/s]


Computing centroids of events...
There is 596 detected events, 112 removed

Extracting mini's properties of stream /streams/stream-5


31 Regions removed so far: 100%|██████████| 431/431 [00:04<00:00, 100.24it/s]


Computing centroids of events...
There is 400 detected events, 31 removed

Extracting mini's properties of stream /streams/stream-6


11 Regions removed so far: 100%|██████████| 396/396 [00:01<00:00, 238.63it/s]


Computing centroids of events...
There is 385 detected events, 11 removed


## Fine Tuning Network
We will fine-tune a pre-trained network and save it.
We also specify which streams are used to fine-tune the network.
We need to provide the parameters of the Positive and Unlabeled (PU) extraction method and
the parameters of the fine-tuning itself. Any fine-tuning parameter that is not specified
is taken from the original training configuration.
**This step can take a while depending on the number of steps.**
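The rule "missing fine-tuning parameters are taken from the training configuration" amounts to a dictionary merge where fine-tuning values take precedence. A minimal sketch, assuming both configurations are flat Python dictionaries; `merge_configs` and the `learning_rate`/`batch_size` keys are hypothetical, only `num_steps` comes from this notebook.

```python
import copy


def merge_configs(training_config, finetuning_config):
    """Hypothetical helper: start from the training config, override with fine-tuning values."""
    merged = copy.deepcopy(training_config)  # leave the original training config untouched
    merged.update(finetuning_config)         # top-level keys only; nested dicts are replaced whole
    return merged


training_config = {"num_steps": 100000, "learning_rate": 1e-3, "batch_size": 32}
finetuning_config = merge_configs(training_config, {"num_steps": 10000})
```

Here `finetuning_config` keeps the training learning rate and batch size but runs for only 10000 steps, which mirrors `NUM_STEPS` below.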

In [23]:
FINETUNE_NETWORK = False
PRE_TRAINED_NETWORK_PATH = Path("/home/gabriel/results/atd/unet2d/unet2d-foldB-2fgr") # the pre-trained network
FINE_TUNED_NETWORK_PATH = Path("/home/gabriel/results/fine-tuned/sted/test/") # the new fine-tuned network
DRY_RUN = True # If DRY_RUN is True, the fine-tuning won't save any checkpoints (for testing purposes)
STREAM_CONDITION = 'sted'

# P and U extraction parameters
CROP_SIZE = 64
FOREGROUND_RATIO = 0
NUMBER_OF_POSITIVE = 500
SHAPE2CROP = [5, 64, 64]
PATTERN = [-2, -1, 0, 1, 2]

# Fine-tuning parameters
NUM_STEPS = 10000
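To make the PU extraction parameters concrete: `PATTERN` lists frame offsets around an event, so each crop spans `len(PATTERN)` frames of `CROP_SIZE x CROP_SIZE` pixels, which matches `SHAPE2CROP = [5, 64, 64]`. The sketch below is a hypothetical illustration of that idea, not the code behind `db.fine_tune_network`: positive crops are centred on known event centroids `(t, y, x)`, unlabeled crops are taken at random positions (they may or may not contain events), and `FOREGROUND_RATIO` is not modelled. `extract_crops` is an invented name, and the stream is assumed to be at least `crop_size` pixels wide and high.

```python
import numpy as np


def extract_crops(stream, centroids, pattern=(-2, -1, 0, 1, 2), crop_size=64,
                  number_of_positive=500, seed=None):
    """Hypothetical PU crop extraction from a (T, H, W) stream.

    Returns (positives, unlabeled), each of shape (N, len(pattern), crop_size, crop_size).
    """
    rng = np.random.default_rng(seed)
    pattern = np.asarray(pattern)
    T, H, W = stream.shape
    half = crop_size // 2

    def crop(t, y, x):
        # Clip frame indices and window corners so the crop stays inside the stream
        frames = np.clip(t + pattern, 0, T - 1)
        y0 = int(np.clip(y - half, 0, H - crop_size))
        x0 = int(np.clip(x - half, 0, W - crop_size))
        return stream[frames, y0:y0 + crop_size, x0:x0 + crop_size]

    # Positives: crops centred on (at most number_of_positive) event centroids
    positives = [crop(*np.round(c).astype(int)) for c in centroids[:number_of_positive]]
    # Unlabeled: as many crops at uniformly random positions
    unlabeled = [crop(rng.integers(T), rng.integers(H), rng.integers(W))
                 for _ in range(len(positives))]
    return np.array(positives), np.array(unlabeled)
```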

In [None]:
db = database.Database(DATABASE_FILE)
if FINETUNE_NETWORK:
    # Create the folder which will contain the fine tuned network
    if FINE_TUNED_NETWORK_PATH.exists():
        shutil.rmtree(FINE_TUNED_NETWORK_PATH)
    shutil.copytree(PRE_TRAINED_NETWORK_PATH, FINE_TUNED_NETWORK_PATH)
    if DRY_RUN:
        print("Dry run is set, the fine-tuning won't be saved...")

    # Define the config dictionary from the parameters
    configs = {
        "dataset_config": {
            "number-of-positive": NUMBER_OF_POSITIVE,
            "foreground-ratio": FOREGROUND_RATIO,
            "crop-size": CROP_SIZE,
            "pattern": PATTERN
        },
        "finetuning_config": {
            "num_steps": NUM_STEPS,
        }
    }
    
    # Define the network experiment wrapper
    pretrained_xp = Experiment(str(PRE_TRAINED_NETWORK_PATH), save=False, verbose=False)
    finetuned_xp = Experiment(str(FINE_TUNED_NETWORK_PATH), save=(not DRY_RUN), verbose=False)
    
    # Define the fine tuning training streams
    train_data = db.filter_by('stream-condition', 'train')

    # Define the fine tuning validation streams
    valid_data = db.filter_by('stream-condition', 'valid')
    
    # Do the actual fine tuning
    db.fine_tune_network(train_data, valid_data, finetuned_xp, configs)
    
db.close()

## Visualization
Let's see how the fine-tuned network performs compared to the pre-trained network and MiniFinder.
We will use the streams whose 'stream-condition' is 'to_test_fine_tuning'.

In [44]:
STREAM_CONDITION = 'to_test_fine_tuning'
PRE_TRAINED_NETWORK_PATH = Path("/home/gabriel/results/atd/unet2d/unet2d-foldB-2fgr") # the pre-trained network
FINE_TUNED_NETWORK_PATH = Path("/home/gabriel/results/fine-tuned/sted/test/") # the new fine-tuned network
FINE_TUNE_FILENAME = 'mask_finetuned.tif'
PRE_TRAINED_FILENAME = 'mask_pretrained.tif'
MINI_FINDER_FILENAME = 'mask_minifinder.tif'
STREAM_FILENAME = 'stream.tif'

In [42]:
db = database.Database(DATABASE_FILE)
test_streams = db.filter_by('stream-condition', STREAM_CONDITION)
assert len(test_streams) == 2 # There should be 2 streams
with_mask, without_mask = test_streams

# Define the network experiment wrapper
pretrained_xp = Experiment(str(PRE_TRAINED_NETWORK_PATH), save=False, verbose=False)
finetuned_xp = Experiment(str(FINE_TUNED_NETWORK_PATH), save=False, verbose=False) # inference only, nothing to save

# We start by inferring with the pre-trained network and save the results
db.segment_and_detect_msct(with_mask, pretrained_xp, redo=True)
tifffile.imsave(PRE_TRAINED_FILENAME, with_mask['segmentation'][...].astype(np.int16))

# We then infer with the fine-tuned network and save the results
db.segment_and_detect_msct(with_mask, finetuned_xp, redo=True)
tifffile.imsave(FINE_TUNE_FILENAME, with_mask['segmentation'][...].astype(np.int16))

# We save the MiniFinder mask for convenience
tifffile.imsave(MINI_FINDER_FILENAME, db.load_mini_mask(with_mask).astype(np.int16))

# We also save the stream for convenience
tifffile.imsave(STREAM_FILENAME, db.load_raw_stream(with_mask).astype(np.int16))
db.close()
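Besides inspecting the saved TIFF files visually, one simple way to quantify the agreement between two masks is a voxel-wise Dice coefficient and intersection-over-union (IoU). This is a suggested sketch, not part of the original analysis: it treats any non-zero voxel as foreground, so it measures overlap of segmented pixels rather than event-level detection accuracy. `dice_and_iou` is a hypothetical helper.

```python
import numpy as np


def dice_and_iou(mask_a, mask_b):
    """Voxel-wise Dice coefficient and IoU between two masks (non-zero = foreground)."""
    a = np.asarray(mask_a) != 0
    b = np.asarray(mask_b) != 0
    intersection = np.logical_and(a, b).sum()
    total = a.sum() + b.sum()
    union = np.logical_or(a, b).sum()
    if total == 0:
        return 1.0, 1.0  # both masks empty: treat as perfect agreement
    return 2.0 * intersection / total, intersection / union
```

For example, `dice_and_iou(tifffile.imread(FINE_TUNE_FILENAME), tifffile.imread(MINI_FINDER_FILENAME))` compares the fine-tuned network with MiniFinder, and the same call with `PRE_TRAINED_FILENAME` gives the baseline.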