In [3]:
from __future__ import print_function
import torch

import json
import time
from datetime import date, datetime

from pathlib import Path
import shutil
import os

import numpy as np
import pandas as pd

import requests
import itertools
from collections import OrderedDict

import skimage.transform as skt
import skimage

from random import sample
from scipy.special import softmax

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# brtdevkit
import brtdevkit
from brtdevkit.core.db import DBConnector
from brtdevkit.core.db.db_filters import *  # We need this for pre-defined filters, e.g., ProjectFilter, DatetimeFilter
import brtdevkit.util.s3 as brt_s3

# dl-core
import brtdl.metrics
from brtdl.metrics.evaluate_metrics import evaluate_metrics
from brtdl import inference, default_arguments
from brtdl.data import canonical_types, SampleKeys
from brtdl.data.dataset import PathDataset, PredictionDataset
from brtdl.data.loaders import im_reader, npy_reader
from brtdl.transforms import SegmentationTransform
from brtdl.visual import colorize_segmentation

from warnings import filterwarnings
filterwarnings("ignore")

In [4]:
# Registry of the segmentation models this notebook can use, keyed by model name.
# The latest models are listed on the 2021 Model Tracking Page:
# https://bluerivertechnology.atlassian.net/wiki/spaces/SNS/pages/1865318625/2021+Model+Tracker
#
# Each value can be a locally stored file or a url.
# Models need to have one output, not the image quality multi-headed outputs (dust, blur, etc).
model_paths = dict([
    ('20210317_1_corn_4',
     '/home/williamroberts/code/brtdevkit/Projects/Inference Images/models/20210317_1_corn_4.jit'),
    ('20210308_1_soybeans_4',
     '/home/williamroberts/code/brtdevkit/Projects/Inference Images/models/20210308_1_soybeans_4.jit'),
    ('20210317_1_cotton_4',
     '/home/williamroberts/code/brtdevkit/Projects/Inference Images/models/20210317_1_cotton_4.jit'),
    ('20210220_1_fallow_4',
     '/home/williamroberts/code/brtdevkit/Projects/Inference Images/models/20210220_1_fallow_4.jit'),
    ('allSoy_0',
     '/home/williamroberts/code/brtdevkit/Projects/Inference Images/models/allSoy_0.jit'),
    ('20200528_1_soybeans',
     'https://artifactory.bluerivertech.com/artifactory/dev-shashta-models-local/jit/20200528_1_soybeans/20200528_1_soybeans.jit'),
])


## Helper Functions to Facilitate Inference

In [5]:
# Functions to download inputs and run inference

def download_images(df, artifact_kind, outputdir):
    """
    Download one artifact per dataframe row from S3 into a local directory.

    For each row, the first entry of ``row.artifacts`` whose ``'kind'`` equals
    ``artifact_kind`` is downloaded; rows without such an artifact are skipped.
    Each file is saved under the base name of its S3 key.

    :param df: dataframe, expected to be created by a query to brtdevkit
    :param artifact_kind: string, the artifact 'kind' to download (e.g. 'nrg')
    :param outputdir: output directory to store the downloaded images
    """
    client = brt_s3.S3()
    for _, record in df.iterrows():
        matching = [entry for entry in record.artifacts if entry['kind'] == artifact_kind]
        if not matching:
            # Nothing of the requested kind on this row.
            continue
        chosen = matching[0]
        destination = os.path.join(outputdir, os.path.basename(chosen['s3_key']))
        client.download_file(chosen['s3_bucket'], chosen['s3_key'], destination)
    print('Source image download complete')
    
def download_from_url(url, fname):
    """
    Helper function to download data given an URL.

    The response is streamed to disk in chunks, so large model files are not
    held in memory. A 4xx/5xx answer raises instead of having the server's
    error page silently written to ``fname``.

    :param url: string, url of the file to be downloaded
    :param fname: filename of the downloaded file
    :raises requests.HTTPError: if the server responds with a 4xx/5xx status
    """
    # NOTE(review): redirects are not followed, so the body of a 3xx response
    # would be written as-is -- confirm this is intended.
    with requests.get(url, allow_redirects=False, stream=True) as response:
        # Fail before opening the output file so no bogus file is left behind.
        response.raise_for_status()
        with open(fname, 'wb') as fd:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                fd.write(chunk)

def get_id_from_s3(df, key):
    """
    Print the ``_id`` of every row whose 'nrg' artifact matches an s3 key.

    Only the first 'nrg' artifact of each row is looked at, and the first 19
    characters of its ``s3_key`` are dropped before comparing with ``key``.
    Prints 'Key not found' when no row matches.
    """
    match_count = 0
    for _, record in df.iterrows():
        nrg_artifacts = [entry for entry in record.artifacts if entry['kind'] == 'nrg']
        if not nrg_artifacts:
            continue
        # Compare against the key with its 19-character prefix stripped.
        if nrg_artifacts[0]['s3_key'][19:] == key:
            print(record['_id'])
            match_count += 1
    if match_count == 0:
        print('Key not found')
            
            
def run_inference_local(img_dir, model_path, result_dir):
    """
    Run a locally stored JIT model over a directory of images and save colorized results.

    Use this function if the images and model are already downloaded and locally available.
    Each prediction is reduced to a per-pixel class map (argmax over axis 0), resized to the
    source image's height and width without interpolation between classes (order=0), overlaid
    on the source image with ``colorize_segmentation`` and written to ``result_dir`` as
    ``<first 36 characters of the image id>.jpg``.

    Both single-output predictions ``(img_id, f_hat)`` and multi-headed predictions
    ``(img_id, f_hat, heads)`` are accepted; extra heads are ignored.

    Inputs:
    img_dir: a directory with png images (should be NRG888) to analyze
    model_path: the locally stored model you want to use
    result_dir: a directory to store the finished result images. Does not need to exist already.
    """
    # exist_ok avoids the check-then-create race; missing parents are created too.
    os.makedirs(result_dir, exist_ok=True)

    channel_order = 'HWC'
    batch_size = 1

    # Hard-coded; the earlier version read (model_metadata['input_shape'][2], model_metadata['input_shape'][3]).
    resize_shape = (540, 960)
    with inference.managed_indir(default_arguments.PathTypes.abspath(img_dir)) as data_dir:
        suffix = inference.infer_suffix(data_dir, '')
        dset = inference.dset_dispatcher[suffix](data_dir, suffix, resize_shape, channel_order)

    # Run inference on downloaded images
    pred_gen = inference.predict_jit(dset, default_arguments.PathTypes.abspath(model_path), batch_size, device='cpu')
    for prediction in pred_gen:
        # Take only the first two fields so 2- and 3-element predictions share one loop.
        # A try/except fallback would drop the first image, because the generator has
        # already advanced by the time the unpacking error is raised.
        img_id, f_hat = prediction[:2]
        y_hat = np.argmax(f_hat, axis=0).astype(np.ubyte)
        # Scale to 0-255 (matplotlib reads PNGs as floats in [0, 1]).
        img = mpimg.imread(os.path.join(img_dir, img_id + suffix)) * 255.0
        img = img.astype(np.uint8)
        h, w, _ = img.shape
        # order=0 keeps the resized map made of the original class labels only.
        resized_yhat = skimage.transform.resize(y_hat, (h, w), order=0,
                                                mode="constant", anti_aliasing=False,
                                                preserve_range=True)
        resized_yhat = resized_yhat.astype(np.uint8)
        colorized_img = colorize_segmentation(img, resized_yhat)
        # NOTE(review): the id is cut to 36 characters -- presumably the UUID part; confirm.
        out_path = os.path.join(result_dir, img_id[:36] + ".jpg")
        skimage.io.imsave(out_path, colorized_img)
        
def return_prediction(img_dir, model_path):
    """
    Use this function to run inference and return just the prediction for a comparison.

    Only the FIRST image yielded by the predictor is processed: its per-pixel class map
    (argmax over axis 0) is resized to the source image's height and width and returned.
    Both ``(img_id, f_hat)`` and ``(img_id, f_hat, heads)`` predictions are accepted.

    Inputs:
    img_dir: a directory with png images (should be NRG888) to analyze
    model_path: the location of the model you want to use

    Returns:
    uint8 array of class labels with the source image's height and width, or None when the
    predictor yields nothing.
    """

    channel_order = 'HWC'
    batch_size = 1

    # Hard-coded; the earlier version read (model_metadata['input_shape'][2], model_metadata['input_shape'][3]).
    resize_shape = (540, 960)
    with inference.managed_indir(default_arguments.PathTypes.abspath(img_dir)) as data_dir:
        suffix = inference.infer_suffix(data_dir, '')
        dset = inference.dset_dispatcher[suffix](data_dir, suffix, resize_shape, channel_order)

    # Run inference on downloaded images
    pred_gen = inference.predict_jit(dset, default_arguments.PathTypes.abspath(model_path), batch_size, device='cpu')
    for prediction in pred_gen:
        # Ignore any extra heads so multi-headed models do not raise on unpacking.
        img_id, f_hat = prediction[:2]
        y_hat = np.argmax(f_hat, axis=0).astype(np.ubyte)
        # The source image is read only to obtain the target height and width.
        img = mpimg.imread(os.path.join(img_dir, img_id + suffix)) * 255.0
        img = img.astype(np.uint8)
        h, w, _ = img.shape
        resized_yhat = skimage.transform.resize(y_hat, (h, w), order=0,
                                                mode="constant", anti_aliasing=False,
                                                preserve_range=True)
        # Returning inside the loop is what limits the result to the first image.
        return resized_yhat.astype(np.uint8)
    return None
        
def run_inference_from_web(image_initial, model_names, show_model_metadata=False):
    """
    Download the NRG images of a query result and run one or more models over them.

    Inputs:
    image_initial: a pandas Dataframe queried from brtdevkit (where you can access the image s3 keys)
    model_names: a list containing model names; each must be a key of the module-level
                 ``model_paths`` dictionary
    show_model_metadata: when True, print the ``metadata.json`` stored next to each model

    Outputs, contained in separate directories under ``<analysis_name>_<YYYY_MM_DD>`` in the
    current working directory:
    Downloaded NRG Images
    Model downloaded from local file or artifactory url
    Results images generated for each model

    Notes:
    - Reads the module-level ``analysis_name`` and ``model_paths``; both must be defined
      before calling.
    - An existing output directory with the same name is deleted first.
    - Models are looked up in ``model_paths``; the name ``soy_models`` used previously is not
      defined in this notebook.
    """

    # Throw an error if the input DF is empty
    assert len(image_initial) > 0, 'Input DataFrame is empty. Nothing to analyze.'

    # Throw an error if more than one crop is represented in the DataFrame
    assert len(image_initial.crop_name.unique())==1, 'Multiple crop_names in DataFrame'

    today = date.today()
    out_name = f'{analysis_name}_{str(today).replace("-", "_")}'
    out_dir = Path(out_name)
    if out_dir.exists():
        # Start from a clean directory so stale results are never mixed with new ones.
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=False, exist_ok=True)

    # Create directory for base images to analyze
    img_dir = out_dir / 'nrg_images'
    img_dir.mkdir(parents=False, exist_ok=True)

    # Download Images To Infer
    download_images(image_initial, 'nrg', img_dir)

    for m in model_names:
        model_source = model_paths[m]
        is_remote = model_source.startswith('http')

        result_dir = out_dir / (m + '_result_images')
        result_dir.mkdir(parents=False, exist_ok=True)

        if show_model_metadata:
            # metadata.json is expected to sit next to the model file / url.
            metadata_source = os.path.join(os.path.dirname(model_source), 'metadata.json')
            if is_remote:
                metadata_path = out_dir / 'metadata.json'
                download_from_url(metadata_source, metadata_path)
            else:
                # Local models cannot be fetched over HTTP; read the sibling file directly.
                metadata_path = Path(metadata_source)
            if metadata_path.exists():
                with open(metadata_path, 'r') as f:
                    model_metadata = json.load(f)
                print(m+' metadata.json:')
                print(model_metadata)
                print(' ')
            else:
                print(f'No metadata.json found for {m}')

        # TODO: Will need to change this to be available for multiple crop_names, not just soy
        # Download the model when it is a url; otherwise use the local file in place.
        if is_remote:
            model_path = out_dir / (m + '.jit')
            download_from_url(model_source, model_path)
        else:
            model_path = model_source
        run_inference_local(img_dir, model_path, result_dir)
        print(f'Inference completed for {m}')

def softmax_inference_local(img_dir, model_path, result_dir, threshold):
    """
    Run a local JIT model over a directory of images using a thresholded softmax decision.

    Use this function if the images and model are already downloaded and locally available.
    The raw output is turned into probabilities with a softmax over axis 0. A pixel is labelled
    2 where channel 2 is more probable than channel 0 and 0 otherwise; it is then overridden
    to 1 wherever channel 1's probability exceeds ``threshold``. The label map is resized to
    the source image, colorized and written to ``result_dir`` as ``<img_id>.jpg``.

    Both ``(img_id, f_hat)`` and ``(img_id, f_hat, heads)`` predictions are accepted; extra
    heads are ignored.

    Inputs:
    img_dir: a directory with png images (should be NRG888) to analyze
    model_path: the location of the model you want to use
    result_dir: a directory to store the finished result images. Does not need to exist already.
    threshold: probability above which channel 1 wins over the channel 0 / channel 2 decision
    """
    # exist_ok avoids the check-then-create race; missing parents are created too.
    os.makedirs(result_dir, exist_ok=True)

    channel_order = 'HWC'
    batch_size = 1

    # Hard-coded; the earlier version read (model_metadata['input_shape'][2], model_metadata['input_shape'][3]).
    resize_shape = (540, 960)
    with inference.managed_indir(default_arguments.PathTypes.abspath(img_dir)) as data_dir:
        suffix = inference.infer_suffix(data_dir, '')
        dset = inference.dset_dispatcher[suffix](data_dir, suffix, resize_shape, channel_order)

    # Run inference on downloaded images
    pred_gen = inference.predict_jit(dset, default_arguments.PathTypes.abspath(model_path), batch_size, device='cpu')
    for prediction in pred_gen:
        # Ignore any extra heads so single-output models do not raise on unpacking.
        img_id, f_hat = prediction[:2]
        probs = softmax(f_hat, axis=0)
        # Boolean comparison times 2 gives labels in {0, 2}.
        y_hat = np.greater(probs[2, :, :], probs[0, :, :]) * 2
        # Channel 1 takes precedence wherever it clears the threshold.
        y_hat[probs[1, :, :] > threshold] = 1
        y_hat = y_hat.astype(np.ubyte)
        # Scale to 0-255 (matplotlib reads PNGs as floats in [0, 1]).
        img = mpimg.imread(os.path.join(img_dir, img_id + suffix)) * 255.0
        img = img.astype(np.uint8)
        h, w, _ = img.shape
        resized_yhat = skimage.transform.resize(y_hat, (h, w), order=0,
                                                mode="constant", anti_aliasing=False,
                                                preserve_range=True)
        resized_yhat = resized_yhat.astype(np.uint8)
        colorized_img = colorize_segmentation(img, resized_yhat)
        out_path = os.path.join(result_dir, img_id + ".jpg")
        skimage.io.imsave(out_path, colorized_img)
        
def generate_evaluations(base_dir, fields, n_images = 50):
    """
    Sample images from each field, run the crop-appropriate model and save the results.

    Model names and analysis name should be set up in a previous cell: this function reads
    the module-level ``full_df``, ``analysis_name``, ``model_paths`` and the
    ``corn_model`` / ``soy_model`` / ``cotton_model`` / ``fallow_model`` names.

    For every field the layout ``<base_dir><analysis_name>/<field>/`` is created with an
    ``nrg_images`` directory (emptied if it already exists) and a ``<model>_results``
    directory. Fields with no rows or with an unrecognized crop are reported and skipped;
    the remaining fields are still processed.

    Inputs:
    base_dir: directory prefix the analysis directory name is appended to (it is concatenated
              as-is, so it should end with a path separator)
    fields: list of grower_farm_field names to run inference on and store results
    n_images: maximum number of images to sample per field; fields with fewer images use
              all of them. Sampling is random and unseeded, so runs are not repeatable.
    """
    analysis_root = base_dir + analysis_name

    for f in fields:
        field_df = full_df[full_df['grower_farm_field']==f]
        if field_df.empty:
            # e.g. a NaN field key never compares equal to itself.
            print(f'No images found for {f}')
            continue

        # assign models by crop (the first crop_name found in the field decides)
        crop = field_df.crop_name.unique()[0]
        if crop == 'CORN':
            model = corn_model
        elif crop == 'SOYBEANS':
            model = soy_model
        elif crop == 'COTTON':
            model = cotton_model
        elif crop in ('NONE_FALLOW_PRE_EMERGE', 'OTHER'):
            model = fallow_model
        else:
            # Skip just this field; one unknown crop should not abort the whole run.
            print(f'Unrecognized crop in {f}')
            continue

        # Create directories for analysis, images
        analysis_dir = analysis_root + '/' + str(f)
        nrg_dir = analysis_dir + '/nrg_images'
        if os.path.exists(nrg_dir):
            # Remove images left over from an earlier run of the same field.
            shutil.rmtree(nrg_dir)
        os.makedirs(nrg_dir)

        model_path = model_paths[model]

        # Choose images and download; cap the sample size because random.sample raises
        # ValueError when asked for more items than the population holds.
        image_ids = field_df['_id'].to_list()
        image_sample = sample(image_ids, min(n_images, len(image_ids)))
        sparse = field_df[field_df['_id'].isin(image_sample)]
        download_images(sparse, 'nrg', nrg_dir)

        # create directory for results
        results_dir = analysis_dir + '/' + str(model) + '_results'

        # Save results in the result dir
        run_inference_local(nrg_dir, model_path, results_dir)
        print(f'Finished with {f}')

In [6]:
# Query dataframe of fields to evaluate

def get_shasta_data(filters={}, start=None, end=None, limit=None):
    """
    Fetch 'shasta' project image documents as a dataframe and time the query.

    :param filters: extra query fields merged on top of ``{'project_name': 'shasta'}``
    :param start: optional lower bound passed to a ``DatetimeFilter`` on 'collected_on'
    :param end: optional upper bound passed to a ``DatetimeFilter`` on 'collected_on'
    :param limit: forwarded unchanged to ``get_documents_df``
    :return: tuple ``(df, elapsed_time)`` where elapsed_time is in seconds and includes
             creating the connector
    """
    t0 = time.time()
    db = DBConnector()

    query = {'project_name': 'shasta'}
    query.update(filters)

    date_bounded = not (start is None and end is None)
    if date_bounded:
        # A date window is expressed as a list: the plain query plus a datetime filter.
        query = [query, DatetimeFilter(key="collected_on", start=start, end=end)]

    df = db.get_documents_df('image', query, limit=limit)
    elapsed_time = time.time() - t0
    return df, elapsed_time

# Collection-date window for the query
start = datetime(2021, 4, 12)
end = datetime(2021, 5, 31)

# Known robot / machine names and ISP versions.
# Only dcms_2021 is used by the filter below; the others are kept for reference.
dcms = [
    'DCM-MANATEE',
    'DCM-WALRUS',
    'DCM-SEAL',
    'DCM-OTTER',
    'DCM-PORPOISE',
    'DCM-DOLPHIN',
]
dcms_2021 = ['DCM11', 'DCM12', 'DCM13', 'DCM14', 'DCM16']
machines = [
    'SHASTA-FB-BRADLEY',
    'SHASTA-FB-PALADIN',
    'BLACKBIRD',
    'ATM-DUCKDUCK',
    'ATM-GOOSE',
]
valid_isp_versions = ['07080203', '07090000', '07090100']

# Query fields: images that have an 'nrg' artifact and were taken by a 2021 DCM
filters = {
    'artifacts.kind': 'nrg',
    'robot_name': {'$in': dcms_2021},
}

full_df, elapsed_time = get_shasta_data(filters=filters, start=start, end=end)

# Derived columns: the calendar date of collection and a combined field identifier
full_df['date_collected'] = pd.to_datetime(full_df['collected_on']).dt.date
full_df['grower_farm_field'] = (
    full_df['grower']
    + '_'
    + full_df['farm']
    + '_'
    + full_df['operating_field_name']
)
print(f"Queried {len(full_df)} images in {elapsed_time:.2f} s.")

Queried 87870 images in 43.45 s.


In [8]:
# Choose a name for your analysis; today's date (YYYY_MM_DD) is appended to it
today = date.today()
analysis_name = '_'.join(['Early_MAY', str(today).replace("-", "_")])
print('Results will be stored in:', analysis_name)

# models to use in inference, one per crop (keys of model_paths)
corn_model, soy_model = '20210317_1_corn_4', '20210308_1_soybeans_4'
cotton_model, fallow_model = '20210317_1_cotton_4', '20210220_1_fallow_4'

# Directory prefix under which the analysis directory is created
base_dir = '/home/williamroberts/code/brtdevkit/Projects/Inference Images/'

Results will be stored in: Early_MAY_2021_05_07


In [None]:
# Run inference for every field in the query result and save the results locally
generate_evaluations(
    base_dir=base_dir,
    fields=full_df.grower_farm_field.unique(),
    n_images=50,
)