#### Instructions

The model can be loaded by first creating a conda environment from the env_for_fern_segmentation.txt file. You must have conda installed. Run the following two lines in a terminal first.

In [1]:
# conda create --name herb_segmentation --file env_for_fern_segmentation.txt
# conda activate herb_segmentation

Note: fastai and pytorch are not simple to install. Be prepared to spend some time getting fastai and pytorch in the right configuration according to the env_for_fern_segmentation.txt specifications. This may include downloading older versions of these libraries, as they are in active development and new versions are released frequently. 

In [2]:
import shutil
from PIL import Image as image_save
import itertools
import operator
import fastai
from fastai import *
from fastai.vision import *
from fastai.vision.models.wrn import wrn_22
import functools, traceback
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

## Load the custom classes and the trained model "fern_segmentation.pkl"

In [3]:
class SegLabelListCustom(SegmentationLabelList):
    """Segmentation label list that opens its masks with ``div=True``."""

    def open(self, fn):
        """Open the mask stored at ``fn`` using ``open_mask(fn, div=True)``.

        NOTE(review): ``div=True`` presumably rescales 0/255 mask values to
        0/1 -- confirm against fastai's ``open_mask``.
        """
        mask = open_mask(fn, div=True)
        return mask
    
class SegItemListCustom(SegmentationItemList):
    """Segmentation item list whose labels are built with ``SegLabelListCustom``."""
    # Label class associated with this item list.
    _label_cls = SegLabelListCustom

In [4]:
import os
from urllib.request import urlretrieve

# URL of the trained model and the local filename it is stored under.
pickle_url = 'https://www.dropbox.com/s/rwwesulkcbspkhf/fern_segmentation.pkl?dl=1'
pickle_file = 'fern_segmentation.pkl'

# Fetch the model only when it is not already in the working directory.
if not os.path.exists(pickle_file):
    # Download under a temporary name and rename only after the transfer has
    # finished: an interrupted download must not leave a truncated
    # 'fern_segmentation.pkl' behind that later runs would mistake for a
    # complete file and skip re-downloading.
    partial_file = pickle_file + '.part'
    urlretrieve(pickle_url, partial_file)
    os.replace(partial_file, pickle_file)

In [5]:
# Directory that contains fern_segmentation.pkl. The download cell saves the
# file under a relative name, i.e. into the current working directory.
path_to_pickle = '.'

In [6]:
# Restore the trained learner from <path_to_pickle>/fern_segmentation.pkl.
# NOTE(review): a .pkl file is presumably unpickled here, which can execute
# arbitrary code -- only load a model file obtained from a trusted source.
seg_bot = load_learner(path=path_to_pickle, file='fern_segmentation.pkl')

## Running a single image through the model

In [7]:
path_img = Path("examples/example1.jpg") # MUST CHANGE! Path to the single image to segment.

In [8]:
# Load the image from disk and run it through the trained learner.
img = open_image(path_img)
# As the displayed result shows, predict() returns a 3-tuple here:
# an ImageSegment of shape (1, 256, 256) followed by two tensors.
img_mask_pred = seg_bot.predict(img)

In [9]:
# Display the raw prediction tuple returned by seg_bot.predict.
img_mask_pred

(ImageSegment (1, 256, 256),
 tensor([[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]]),
 tensor([[[9.9993e-01, 9.9999e-01, 1.0000e+00,  ..., 9.9998e-01,
           9.9967e-01, 9.9851e-01],
          [9.9998e-01, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
           9.9995e-01, 9.9975e-01],
          [9.9999e-01, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
           9.9999e-01, 9.9994e-01],
          ...,
          [9.9995e-01, 9.9999e-01, 1.0000e+00,  ..., 9.9996e-01,
           9.9970e-01, 9.9586e-01],
          [9.9940e-01, 9.9984e-01, 1.0000e+00,  ..., 9.9971e-01,
           9.9649e-01, 9.6665e-01],
          [9.9241e-01, 9.9678e-01, 9.9978e-01,  ..., 9.9446e-01,
           9.6365e-01, 8.5454e-01]],
 
         [[7.1575e-05, 1.0269e-05, 1.9280e-06,  ..., 1.9599e-05,
           3.3051e-04, 1.4852e-03],
          [2

## Running a large batch of images through the model and saving the masked versions

In [10]:
path_to_images = 'examples' # MUST CHANGE! Folder containing the .jpg images to segment.

In [11]:
# Build a label-less DataBunch from the .jpg files under `path_to_images`:
# no train/validation split, empty labels, no augmentation (tfms=None),
# size=256, batch size 64, normalised with `imagenet_stats`.
data_test = (
    ImageList.from_folder(path=path_to_images, extensions=".jpg")
    .split_none()
    .label_empty()
    .transform(tfms=None, size=256)
    .databunch(bs=64)
    .normalize(imagenet_stats)
)



                 Your batch size is 64, you should lower it.
  Your batch size is {self.train_dl.batch_size}, you should lower it.""")


In [12]:
# NOTE(review): `data_test` was built with `.databunch(bs=64)`; changing `bs`
# here does not change that call, so keep the two values in sync.
bs = 64 # could change this if you are having issues with memory
# Use data_test's `fix_dl` loader as the learner's test dataloader.
seg_bot.data.test_dl = data_test.fix_dl

In [13]:
# Number of *complete* batches of size `bs` in the test dataset.
# Floor division: a trailing partial batch is not counted.
number_of_batches = len(seg_bot.data.test_ds) // bs

In [14]:
path_to_save_masked_images = "examples_masked" # MUST CHANGE! Folder the masked images are written into.

In [15]:
# Make sure the output folder exists before the first image is written.
os.makedirs(path_to_save_masked_images, exist_ok=True)

# Filenames and dataset items are consumed in lock-step with the batches.
# This relies on the test dataloader yielding items in dataset order.
test_filenames_iter = iter(seg_bot.data.test_ds.items)
test_images_iter = iter(seg_bot.data.test_ds)

# Iterate over the dataloader itself rather than a pre-computed batch count,
# so the final (possibly smaller) batch is processed too and folders with
# fewer images than one batch still produce output.
for batch in seg_bot.data.test_dl:
    preds_tup = seg_bot.pred_batch(batch=batch)
    # Per-pixel class index: argmax over the class axis of the predictions.
    pred_masks = np.argmax(preds_tup, axis=1)
    # zip() stops as soon as this batch's masks run out, so exactly one
    # filename and one dataset item are consumed per prediction, whatever
    # the actual size of the batch is.
    for pred_mask, pred_name, orig_item in zip(pred_masks,
                                               test_filenames_iter,
                                               test_images_iter):
        # Each dataset item is indexed as (image, label); take the image tensor.
        orig_loaded_img = orig_item[0].data
        # Add a leading channel axis so the mask broadcasts over the image channels.
        pred_mask = pred_mask.unsqueeze(0).double()
        masked = orig_loaded_img.cpu().double() * pred_mask
        # Channels-first -> channels-last, then scale to 0..255 and round to uint8.
        mask2 = masked.data.permute(1, 2, 0)
        ndarr = mask2.mul_(255).add_(0.5).clamp_(0, 255).to('cpu', torch.uint8).numpy()
        im = image_save.fromarray(ndarr)
        im.save(os.path.join(path_to_save_masked_images,
                             "masked_" + pred_name.parts[-1]))