# Segment 2p - Simple Training

## Instructions

### Step 1: Organize training dataset

The model accepts files organized into the following directory structure:

```
<prefix>/
├── train/
├── train_annotation/
├── validation/
└── validation_annotation/
```

The `train` directory holds the training micrographs as JPEG files. `train_annotation` holds the ground-truth segmentation masks as PNG files, each with the same base name as its image in `train`. Mask pixels must be 0 for regions that are _not_ of interest and 1 for regions of interest. `validation` and `validation_annotation` follow the same rules. All four directories must be stored in an S3 bucket.

### Step 2: Setup model for training

To set up a model for training, first create an estimator object:
```python
ss_model = sagemaker.estimator.Estimator(...
```
Then set the hyperparameters:
```python
ss_model.set_hyperparameters(backbone='resnet-101', # This is the encoder. Other option is resnet-50
                             algorithm='deeplab', # This is the decoder. Other options are 'fcn' and 'psp'
                             use_pretrained_model='False'...
```
See [Training](#Training) for a full demonstration.    

### Step 3: Run model training on data
Finally, map each channel name to its S3 input and pass the dictionary to `ss_model.fit`:
```python
data_channels = {'train': train_data, 
                 'validation': validation_data,
                 'train_annotation': train_annotation, 
                 'validation_annotation':validation_annotation}
```

## Imports

In [None]:
import sys

import sagemaker
from sagemaker import get_execution_role

import matplotlib.pyplot as plt
import PIL
from PIL import Image
import io
import boto3
import numpy as np
from skimage import util 
from skimage.util import img_as_ubyte
from skimage import exposure
from skimage.io import imread as pngread
from skimage.io import imsave as pngsave
from skimage.segmentation import mark_boundaries
from skimage import color

import cv2
from rolling_ball_filter import rolling_ball_filter
import random
import threading

from processfiles import *
from sagemaker.amazon.amazon_estimator import get_image_uri

# SageMaker execution role, session and default bucket used by the cells below.
role = get_execution_role()
print(role)
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Container image of the built-in 'semantic-segmentation' algorithm for this region.
training_image = get_image_uri(sess.boto_region_name, 'semantic-segmentation', repo_version="latest")
print(training_image)

# Two separate S3 resource handles, as before; s3meadata is the 'meadata' bucket
# whose object keys are handed to the processing functions later on.
s3 = boto3.resource('s3')
s3_resource = boto3.resource('s3')
s3meadata = s3_resource.Bucket(name='meadata')

-------
_Setup of our original training set_


## Setup data

In [None]:
%%capture
# Run process functions (raw and filtered versions of fig8 and liorP)
def procfilepar(key):
    proccessliorpreprocfiles(key)
    proccessliorfiles(key)
    proccessfigure8files(key)
    proccessfig8preprocfiles(key)
    proccessusiigacifiles(key)
    proccesshelafiles(key)
    
# Process every object of the 'meadata' bucket concurrently (the work is S3 I/O
# bound) and WAIT for all of it. The previous version started bare threads and
# never joined them, so later cells -- which list the bucket the processors
# presumably write to -- could run on half-processed data, and any exception
# raised in a worker was lost. Consuming pool.map's results re-raises it here.
from concurrent.futures import ThreadPoolExecutor

keys = [obj.key for obj in s3meadata.objects.all()]
with ThreadPoolExecutor() as pool:
    list(pool.map(procfilepar, keys))

### Crop dataset images around labeled areas

In [None]:
# Crop every JPEG under `prefix` in the working bucket around its labelled area.
# `prefix` and `performcrop` are not defined in this notebook section; presumably
# both come from `from processfiles import *` -- confirm.
# The crops run concurrently, but the cell now waits for all of them (and
# surfaces worker exceptions) so the clean-up cells below see the final state.
from concurrent.futures import ThreadPoolExecutor

keys = [obj.key for obj in s3_resource.Bucket(name=bucket).objects.all()
        if ('jpg' in obj.key and prefix in obj.key)]
with ThreadPoolExecutor() as pool:
    list(pool.map(performcrop, keys))

### Delete all files without a matching image/annotation pair

In [None]:
# Delete images/annotations that have no counterpart. `removeunmatched` is not
# defined in this notebook; presumably it comes from `from processfiles import *`
# -- confirm its exact matching rule there. It should only run after the crop
# step has finished writing to the bucket.
removeunmatched()

### Remove samples whose ground-truth masks contain too little foreground

In [None]:
# Drop image/annotation pairs whose mask has unusually little foreground.
# For each annotation PNG under `prefix`, the ratio (foreground pixels /
# background pixels) * 100 is computed; pairs whose ratio is below
# round(mean - std) of all finite ratios are deleted from S3 together with
# their JPEG image.
# NOTE(review): `prefix` is not defined in this notebook section; presumably it
# comes from `from processfiles import *` -- confirm.
import pandas as pd  # pandas is not imported at the top of the notebook

# Channel prefixes; also used by the training cells below.
train_channel = prefix + '/train'
validation_channel = prefix + '/validation'
train_annotation_channel = prefix + '/train_annotation'
validation_annotation_channel = prefix + '/validation_annotation'

# Match on the extension so the png -> jpg key rewrite below is always valid.
keys = [obj.key for obj in s3_resource.Bucket(name=bucket).objects.all()
        if obj.key.endswith('.png') and prefix in obj.key]
segs = []
empties = []
for key in keys:
    masksavepath = "/tmp/" + key.split('/')[-1]
    s3.meta.client.download_file(bucket, key, masksavepath)
    mask = cv2.imread(masksavepath)
    if mask is None:
        # cv2.imread returns None rather than raising; counting pixels of None
        # would silently give 0/0 and poison the statistics.
        raise IOError('Could not read annotation mask {}'.format(key))
    segs.append(np.sum(mask == 1))
    empties.append(np.sum(mask == 0))

segs = np.asarray(segs, dtype=float)
empties = np.asarray(empties, dtype=float)
# A mask with no background pixels yields inf (and a mask with neither 0 nor 1
# pixels yields NaN). Such values must not enter mean/std: a single one turns
# the threshold into NaN, and then no sample is ever removed.
with np.errstate(divide='ignore', invalid='ignore'):
    ratio = (segs / empties) * 100
finite_ratio = ratio[np.isfinite(ratio)]
if finite_ratio.size:
    thresh = np.round(np.mean(finite_ratio) - np.std(finite_ratio))
else:
    thresh = -np.inf  # nothing to compare against: keep every sample

df = pd.DataFrame({'key': keys, 'ratio': ratio, 'empty': ratio < thresh})
removesamples = df.loc[df['empty'], 'key'].values

s3_client = boto3.client('s3')
for removeme in removesamples:
    # <prefix>/<split>_annotation/<name>.png -> <prefix>/<split>/<name>.jpg;
    # only the extension is swapped, not every 'png' occurring in the key.
    image_key = removeme.replace('_annotation/', '/')[:-len('.png')] + '.jpg'
    s3_client.delete_object(Bucket=bucket, Key=removeme)
    s3_client.delete_object(Bucket=bucket, Key=image_key)

--------------------------------
_Finished setup of training set from article_


## Image types and output location

In [None]:
import json

# Label map for the semantic-segmentation algorithm, written to the working
# directory. {"scale": 1} presumably leaves annotation pixel values unchanged --
# confirm against the SageMaker label_map documentation.
# NOTE(review): nothing in this notebook uploads this file or passes it as a
# 'label_map' channel, so training does not appear to use it -- confirm.
label_map = {"scale": 1}
with open('train_label_map.json', mode='w') as label_map_file:
    json.dump(label_map, fp=label_map_file)

In [None]:
# S3 URI under which the training job writes its output artifacts.
s3_output_location = 's3://{bucket}/{prefix}/output'.format(bucket=bucket, prefix=prefix)
print(s3_output_location)

## Training

### Setup Model Hyperparameters

In [None]:
# Create the SageMaker estimator for the built-in semantic-segmentation image.
# The train_* keyword names are the SageMaker Python SDK v1 spellings, matching
# get_image_uri and sagemaker.session.s3_input used elsewhere in this notebook.
ss_model = sagemaker.estimator.Estimator(training_image,
                                         role,
                                         train_instance_count = 1,
                                         train_instance_type = 'ml.p3.16xlarge',
                                         train_volume_size = 300, # size in GB of the storage volume attached to the training instance
                                         train_max_run = 360000, # maximum training time in seconds (100 h)
                                         output_path = s3_output_location,
                                         # Training job names may only contain letters, digits and
                                         # hyphens; an underscore makes CreateTrainingJob reject
                                         # the generated '<base>-<timestamp>' name.
                                         base_job_name = 'segment2p-train',
                                         sagemaker_session = sess)

In [None]:
# Set up hyperparameters for the built-in semantic-segmentation algorithm.
import boto3
s3traindata = boto3.resource('s3').Bucket(name=bucket)
# num_training_samples must describe the 'train' channel actually passed to fit(),
# so count only the JPEGs under <prefix>/train/ rather than every key in the
# bucket containing 'train/' (other prefixes in the bucket would inflate it).
# `train_channel` is defined in the "Remove samples" cell above.
numtrain = len([obj.key for obj in s3traindata.objects.filter(Prefix=train_channel + '/')
                if obj.key.endswith('.jpg')])
ss_model.set_hyperparameters(backbone='resnet-101', # Encoder. Other option: 'resnet-50'.
                             algorithm='deeplab', # Decoder. Other options: 'fcn', 'psp'.
                             use_pretrained_model='False', # Train from scratch: no pre-trained encoder weights.
                             crop_size=412, # Size of image random crop.
                             num_classes=2, # Background + cell
                             epochs=1000, # Maximum number of epochs; early stopping may end sooner.
                             learning_rate=0.003037052721870563,
                             # NOTE(review): momentum presumably only affects momentum-based
                             # optimizers ('sgd', 'nag') and is ignored by 'adagrad' -- confirm.
                             momentum = 0.6133596510181524,
                             weight_decay = 0.0001560844683426084,
                             optimizer='adagrad', # Other options: 'sgd', 'adam', 'rmsprop', 'nag'.
                             lr_scheduler='poly', # Other options: 'cosine', 'step'.
                             mini_batch_size=35, # Training mini-batch size.
                             validation_mini_batch_size=16, # TODO: try larger validation batch sizes.
                             early_stopping=True, # Turn on early stopping. If OFF, other early stopping parameters are ignored.
                             early_stopping_patience=50, # Tolerate this many epochs without mIoU improvement.
                             early_stopping_min_epochs=25, # Always run at least this many epochs.
                             num_training_samples=numtrain)

### Setup data inputs

In [None]:
# Create full bucket names
s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)
s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel)
s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel)

distribution = 'FullyReplicated'
# Create sagemaker s3_input objects
train_data = sagemaker.session.s3_input(s3_train_data, distribution=distribution, 
                                        content_type='image/jpeg', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input(s3_validation_data, distribution=distribution, 
                                        content_type='image/jpeg', s3_data_type='S3Prefix')
train_annotation = sagemaker.session.s3_input(s3_train_annotation, distribution=distribution, 
                                        content_type='image/png', s3_data_type='S3Prefix')
validation_annotation = sagemaker.session.s3_input(s3_validation_annotation, distribution=distribution, 
                                        content_type='image/png', s3_data_type='S3Prefix')

data_channels = {'train': train_data, 
                 'validation': validation_data,
                 'train_annotation': train_annotation, 
                 'validation_annotation':validation_annotation}

### Fit model 

In [None]:
%%capture
# Launch the training job on the four data channels. In the SageMaker SDK, fit()
# presumably waits for the job to finish by default -- confirm.
# NOTE(review): the cell's %%capture magic hides the streamed job logs requested
# with logs=True; remove it to watch training progress in the notebook.
ss_model.fit(inputs=data_channels, logs=True)