# Fine Tuning

We will use this notebook to show how to fine-tune the network with new data.

**User inputs are capitalized throughout this notebook.**

In [1]:
# imports
import sys
sys.path.append("/home/gabriel/dev/automatic-transient-detection")
from pathlib import Path
import shutil
import collections
import numpy as np
from tqdm import tqdm

# DataBase handler
import database

# Network handler
from ATD.common.experiment import Experiment

# Various utils functions
from ATD.utils import ioutils as iou
from ATD.utils import arrayutils as au
from ATD.utils import calciumutils as cu
from ATD.utils import listutils as lu
from ATD.utils import datasetutils as du

# Auto-reload the database module, which is still being edited
%load_ext autoreload
%autoreload 1
%aimport database

## Database Creation
This section creates the database from a [Google spreadsheet](https://docs.google.com/spreadsheets/d/12tEG3494V3y24qU8f11Vzui85sqt3VTwUzHBfEl0sQA/edit#gid=409641959).
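The later outputs (for example `<HDF5 group "/streams/stream-0" (11 members)>`) suggest that each stream is stored as an HDF5 group under `/streams`, with worksheet fields such as `stream-condition` kept as attributes. The sketch below illustrates that layout with `h5py`, assuming this structure; `add_stream` and `filter_by_attr` are hypothetical helpers, not the `database` module's actual implementation.

```python
import h5py


def add_stream(h5file, stream_id, attrs):
    """Create /streams/<stream_id> and store the worksheet fields as HDF5 attributes."""
    group = h5file.require_group("streams").create_group(stream_id)
    for key, value in attrs.items():
        group.attrs[key] = value
    return group


def filter_by_attr(h5file, key, value):
    """Return the stream groups whose attribute `key` equals `value`."""
    return [group for group in h5file["streams"].values() if group.attrs.get(key) == value]


# Hypothetical example: an in-memory HDF5 file (nothing is written to disk)
with h5py.File("example.h5", "w", driver="core", backing_store=False) as f:
    add_stream(f, "stream-0", {"stream-condition": "train", "neuron": 635})
    add_stream(f, "stream-1", {"stream-condition": "valid", "neuron": 577})
    train_names = [g.name for g in filter_by_attr(f, "stream-condition", "train")]

print(train_names)  # ['/streams/stream-0']
```

Keeping the split (`train`, `valid`, ...) as an attribute means a stream can be moved between sets by editing the worksheet and rebuilding, without moving any data.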


In [2]:
CREATE_DATABASE = True
DATABASE_FILE = Path("../h5file/finetune_test2.h5") # location of the database
SPREADSHEET = 'Raw Data List' # title of the spreadsheet document
WORKSHEET = 'FINE_TUNING_STED' # title of the worksheet in the spreadsheet
SECRET_JSON = '../bot-reader-key.json' # credentials file authorizing read access to the spreadsheet

In [3]:
# Get a worksheet object to read from
worksheet = database.Database.oauth_and_get_worksheet(SPREADSHEET, WORKSHEET, SECRET_JSON)

# Remove an existing database if it exists and CREATE_DATABASE is True
if CREATE_DATABASE and DATABASE_FILE.exists():
    DATABASE_FILE.unlink()
    
# Instantiate the database object and populate the database from the worksheet
db = database.Database(DATABASE_FILE)
if CREATE_DATABASE:
    db.load_streams_from_worksheet(worksheet)
db.close()

Creating a new database at ../h5file/finetune_test2.h5...
Adding the stream
Imaging of neuron 635 on coverslip 0
Experiment 100klam with stream condition train
The stream Stream1_Cropped.tif was image on 2019-05-17 00:00:00
Unique stream-id is: stream-0
Adding the stream
Imaging of neuron 636 on coverslip 0
Experiment 100klam with stream condition train
The stream Stream2_Cropped.tif was image on 2019-05-17 00:00:00
Unique stream-id is: stream-1
Adding the stream
Imaging of neuron 576 on coverslip 3
Experiment NMDA with stream condition train
The stream Stream1_Aligned.tif was image on 2020-06-03 00:00:00
Unique stream-id is: stream-2
Adding the stream
Imaging of neuron 576 on coverslip 3
Experiment NMDA with stream condition train
The stream Stream2_Aligned.tif was image on 2020-06-03 00:00:00
Unique stream-id is: stream-3
Adding the stream
Imaging of neuron 577 on coverslip 4
Experiment NMDA with stream condition valid
The stream Stream1_Aligned.tif was image on 2020-06-03 00:00:00
U

## MiniFinder mSCTs Region Props Extraction
Since the masks come from the MiniFinder analysis software (their locations are listed in the Mini-Finder column of the worksheet), we need to extract the properties of the segmented regions from each mask.

In [6]:
EXTRACT_MINI_REGION_PROPS = True
REDO = True # if True, properties that already exist are recomputed
MINI_REGIONPROPS_PARAMS = {
    'msct-props': [
        'area', 
        'convex_area',
        'bbox',
        'max_intensity',
        'major_axis_length',
        'minor_axis_length'
    ],
    'minimal-shape': {
        'minimal_time': 2,
        'minimal_height': 3,
        'minimal_width': 3
    }
}
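The names listed in `'msct-props'` match scikit-image `regionprops` properties, and `'minimal-shape'` presumably discards regions that are too small. The sketch below shows one way to do that filtering, assuming the mask is a `(time, height, width)` array and that a region is kept when its bounding box spans at least `minimal_time` frames and `minimal_height` x `minimal_width` pixels; `filter_small_regions` is a hypothetical helper, not the `database` module's code.

```python
import numpy as np
from skimage.measure import label, regionprops


def filter_small_regions(mask, minimal_time=2, minimal_height=3, minimal_width=3):
    """Label a (time, height, width) mask and drop regions smaller than the minimal shape.

    Returns the centroids of the kept regions and the number of removed regions.
    """
    labels = label(mask > 0)  # 3D connected components
    centroids, removed = [], 0
    for region in regionprops(labels):
        # bbox is (min_t, min_y, min_x, max_t, max_y, max_x), max exclusive
        t0, y0, x0, t1, y1, x1 = region.bbox
        if (t1 - t0) >= minimal_time and (y1 - y0) >= minimal_height and (x1 - x0) >= minimal_width:
            centroids.append(region.centroid)
        else:
            removed += 1
    return centroids, removed


mask = np.zeros((10, 20, 20), dtype=np.uint8)
mask[0:3, 0:4, 0:4] = 1      # 3 frames x 4 x 4 pixels: kept
mask[5:6, 10:11, 10:11] = 1  # a single voxel: removed
centroids, removed = filter_small_regions(mask)
print(len(centroids), removed)  # 1 1
```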

In [7]:
# In this example, one stream has no MiniFinder mask (its stream-condition is 'test_no_mask'),
# so we skip it when extracting properties
db = database.Database(DATABASE_FILE)
if EXTRACT_MINI_REGION_PROPS:
    streams = list(db)
    streams = [s for s in streams if s.attrs['stream-condition'] != 'test_no_mask']
    db.mini_regionprops_extraction(streams, MINI_REGIONPROPS_PARAMS, redo=REDO)
print('Done')
db.close()


Extracting mini's properties of stream /streams/stream-0


16 Regions removed so far: 100%|██████████| 323/323 [00:01<00:00, 163.13it/s]


Computing centroids of events...
There is 307 detected events, 16 removed

Extracting mini's properties of stream /streams/stream-1


9 Regions removed so far: 100%|██████████| 135/135 [00:00<00:00, 228.11it/s]


Computing centroids of events...
There is 126 detected events, 9 removed

Extracting mini's properties of stream /streams/stream-2


26 Regions removed so far: 100%|██████████| 736/736 [00:03<00:00, 192.22it/s]


Computing centroids of events...
There is 710 detected events, 26 removed

Extracting mini's properties of stream /streams/stream-3


24 Regions removed so far: 100%|██████████| 642/642 [00:03<00:00, 186.77it/s]


Computing centroids of events...
There is 618 detected events, 24 removed

Extracting mini's properties of stream /streams/stream-4


56 Regions removed so far: 100%|██████████| 708/708 [00:07<00:00, 92.42it/s] 


Computing centroids of events...
There is 652 detected events, 56 removed

Extracting mini's properties of stream /streams/stream-5


31 Regions removed so far: 100%|██████████| 431/431 [00:04<00:00, 96.72it/s] 


Computing centroids of events...
There is 400 detected events, 31 removed

Extracting mini's properties of stream /streams/stream-6


11 Regions removed so far: 100%|██████████| 396/396 [00:01<00:00, 240.48it/s]


Computing centroids of events...
There is 385 detected events, 11 removed

Extracting mini's properties of stream /streams/stream-8


27 Regions removed so far: 100%|██████████| 346/346 [00:03<00:00, 96.53it/s]


Computing centroids of events...
There is 319 detected events, 27 removed

Extracting mini's properties of stream /streams/stream-9


4 Regions removed so far: 100%|██████████| 54/54 [00:00<00:00, 98.43it/s]


Computing centroids of events...
There is 50 detected events, 4 removed

Extracting mini's properties of stream /streams/stream-10


6 Regions removed so far: 100%|██████████| 94/94 [00:00<00:00, 115.02it/s]


Computing centroids of events...
There is 88 detected events, 6 removed
Done


## Fine Tuning Network
We will fine-tune a pre-trained network and save it.
We also specify which streams are used to fine-tune the network.
We need to provide the parameters of the Positive and Unlabeled (P and U) extraction method and
the parameters of the fine-tuning itself. Any fine-tuning parameter that is not given
is taken from the training configuration.
**This step can take a while depending on the number of steps.**
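The fallback to the training configuration can be pictured as a recursive dictionary merge in which the fine-tuning values override the training values. This is a minimal sketch under that assumption; `merge_configs` and the keys `learning_rate` and `batch_size` are hypothetical and not taken from the `Experiment` class.

```python
import copy


def merge_configs(defaults, overrides):
    """Recursively overlay `overrides` on `defaults`; keys missing from overrides keep the default."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical training configuration of the pre-trained network
training_config = {"num_steps": 100000, "learning_rate": 1e-3, "batch_size": 32}
finetuning_config = merge_configs(training_config, {"num_steps": 10000})
print(finetuning_config)
# {'num_steps': 10000, 'learning_rate': 0.001, 'batch_size': 32}
```

Copying the defaults first leaves the pre-trained configuration untouched, so the same training config can be reused for several fine-tuning runs.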

In [8]:
FINETUNE_NETWORK = True
PRE_TRAINED_NETWORK_PATH = Path("/home/gabriel/results/atd/unet2d/unet2d-foldB-2fgr") # the pre-trained network
FINE_TUNED_NETWORK_PATH = Path("/home/gabriel/results/fine-tuned/sted/test_2/") # the new fine-tuned network
DRY_RUN = False # if True, the fine-tuning won't save any checkpoints (for testing purposes)

# P and U extraction parameters
CROP_SIZE = 64
FOREGROUND_RATIO = 0
NUMBER_OF_POSITIVE = 750
SHAPE2CROP = [5, 64, 64]
PATTERN = [-2, -1, 0, 1, 2]

# Fine-tuning parameters
NUM_STEPS = 10000

In [9]:
db = database.Database(DATABASE_FILE)
if FINETUNE_NETWORK:
    # Create the folder which will contain the fine tuned network
    if FINE_TUNED_NETWORK_PATH.exists():
        shutil.rmtree(FINE_TUNED_NETWORK_PATH)
    shutil.copytree(PRE_TRAINED_NETWORK_PATH, FINE_TUNED_NETWORK_PATH)
    if DRY_RUN:
        print("Dry run is set, the fine-tuning won't be saved...")

    # Define the config dictionary from the parameters
    configs = {
        "dataset_config": {
            "number-of-positive": NUMBER_OF_POSITIVE,
            "foreground-ratio": FOREGROUND_RATIO,
            "crop-size": CROP_SIZE,
            "pattern": PATTERN
        },
        "finetuning_config": {
            "num_steps": NUM_STEPS,
        }
    }
    
    # Define the network experiment wrapper
    pretrained_xp = Experiment(str(PRE_TRAINED_NETWORK_PATH), save=False, verbose=False)
    finetuned_xp = Experiment(str(FINE_TUNED_NETWORK_PATH), save=(not DRY_RUN), verbose=False)
    
    # Define the fine tuning training streams
    train_data = db.filter_by('stream-condition', 'train')

    # Define the fine tuning validation streams
    valid_data = db.filter_by('stream-condition', 'valid')
    
    # Do the actual fine tuning
    db.fine_tune_network(train_data, valid_data, finetuned_xp, configs)
    
db.close()

Fine tuning on 7 training streams...
Fine tuning on 3 validation streams...
There is 307 events in stream <HDF5 group "/streams/stream-0" (11 members)>
There is 126 events in stream <HDF5 group "/streams/stream-1" (11 members)>
There is 710 events in stream <HDF5 group "/streams/stream-2" (11 members)>
There is 618 events in stream <HDF5 group "/streams/stream-3" (11 members)>
There is 319 events in stream <HDF5 group "/streams/stream-8" (11 members)>
There is 50 events in stream <HDF5 group "/streams/stream-9" (11 members)>
There is 88 events in stream <HDF5 group "/streams/stream-10" (11 members)>
750/2218 mSCTs will be extracted...
0 foreground will be extracted...
Loading stream 0...


  9%|▉         | 70/750 [00:00<00:00, 696.69it/s]

Loading stream 1...


 13%|█▎        | 94/750 [00:04<00:39, 16.42it/s] 

Loading stream 2...


 46%|████▋     | 347/750 [00:19<00:20, 19.39it/s]

Loading stream 3...


 76%|███████▋  | 572/750 [00:35<00:06, 25.99it/s]

Loading stream 4...


 89%|████████▊ | 665/750 [00:50<00:11,  7.31it/s]

Loading stream 5...


 95%|█████████▍| 709/750 [01:05<00:08,  5.04it/s]

Loading stream 6...


100%|██████████| 750/750 [01:21<00:00,  9.24it/s]


Fine tuning on 750 training crops...
There is 652 events in stream <HDF5 group "/streams/stream-4" (11 members)>
There is 400 events in stream <HDF5 group "/streams/stream-5" (11 members)>
There is 385 events in stream <HDF5 group "/streams/stream-6" (11 members)>
1437/1437 mSCTs will be extracted...
0 foreground will be extracted...
Loading stream 0...


 44%|████▍     | 636/1437 [00:00<00:01, 710.42it/s]

Loading stream 1...


 70%|███████   | 1012/1437 [00:16<00:03, 135.85it/s]

Loading stream 2...


100%|██████████| 1437/1437 [00:32<00:00, 44.45it/s] 


Fine tuning on 1437 validation crops...
[32m[1m [-] Info: The model has 1777537 parameters[0m
[32m[1m [-] Info: Training for 10000 steps[0m


OSError: [Errno 12] Cannot allocate memory

## Visualization
Let's see how the fine-tuned network performs compared with the pre-trained network and MiniFinder.
We use the streams whose 'stream-condition' is 'to_test_fine_tuning'.

In [44]:
STREAM_CONDITION = 'to_test_fine_tuning'
PRE_TRAINED_NETWORK_PATH = Path("/home/gabriel/results/atd/unet2d/unet2d-foldB-2fgr") # the pre-trained network
FINE_TUNED_NETWORK_PATH = Path("/home/gabriel/results/fine-tuned/sted/test/") # the new fine-tuned network
FINE_TUNE_FILENAME = 'mask_finetuned.tif'
PRE_TRAINED_FILENAME = 'mask_pretrained.tif'
MINI_FINDER_FILENAME = 'mask_minifinder.tif'
STREAM_FILENAME = 'stream.tif'

In [42]:
import tifffile

db = database.Database(DATABASE_FILE)
test_streams = db.filter_by('stream-condition', STREAM_CONDITION)
assert len(test_streams) == 2, "There should be 2 streams"
with_mask, without_mask = test_streams

# Define the network experiment wrappers (inference only, so nothing needs to be saved)
pretrained_xp = Experiment(str(PRE_TRAINED_NETWORK_PATH), save=False, verbose=False)
finetuned_xp = Experiment(str(FINE_TUNED_NETWORK_PATH), save=False, verbose=False)

# Start by inferring with the pre-trained network and save the results
db.segment_and_detect_msct(with_mask, pretrained_xp, redo=True)
tifffile.imwrite(PRE_TRAINED_FILENAME, with_mask['segmentation'][...].astype(np.int16))

# Then infer with the fine-tuned network and save the results
db.segment_and_detect_msct(with_mask, finetuned_xp, redo=True)
tifffile.imwrite(FINE_TUNE_FILENAME, with_mask['segmentation'][...].astype(np.int16))

# Save the MiniFinder mask for convenience
tifffile.imwrite(MINI_FINDER_FILENAME, db.load_mini_mask(with_mask).astype(np.int16))

# Also save the raw stream for convenience
tifffile.imwrite(STREAM_FILENAME, db.load_raw_stream(with_mask).astype(np.int16))

db.close()
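Beyond inspecting the saved TIFFs visually, one option (not part of the original workflow) is to score each network mask against the MiniFinder mask with a pixel-wise Dice coefficient. This is a minimal sketch; `dice_score` is a hypothetical helper, and it treats any non-zero value as foreground.

```python
import numpy as np


def dice_score(mask_a, mask_b):
    """Pixel-wise Dice coefficient between two masks (any non-zero value counts as foreground)."""
    a = np.asarray(mask_a) > 0
    b = np.asarray(mask_b) > 0
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return 2.0 * np.logical_and(a, b).sum() / total


mask_a = np.zeros((2, 4, 4), dtype=np.int16)
mask_b = np.zeros((2, 4, 4), dtype=np.int16)
mask_a[0, :2, :2] = 1  # 4 foreground pixels
mask_b[0, :2, :] = 1   # 8 foreground pixels, 4 shared with mask_a
print(dice_score(mask_a, mask_b))  # about 0.667
```

With the files written above, the masks could be loaded with `tifffile.imread(FINE_TUNE_FILENAME)` and `tifffile.imread(MINI_FINDER_FILENAME)`. Dice measures pixel overlap only; it does not count how many individual mSCTs each method detects.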