In [2]:
from __future__ import print_function

# Standard library
import itertools
import json
import logging
import os
import shutil
from collections import OrderedDict
from datetime import datetime
from pathlib import Path

# Third-party
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
import skimage
import skimage.io
import skimage.transform
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm import tqdm

# Interactive/Widgets
import ipywidgets as widgets
from IPython.core.display import display, HTML
from ipywidgets import interact
import qgrid
from ipywidgets import interact, interactive, fixed, interact_manual
import wandb

# Devkit
import brtdevkit
from brtdevkit.core.db import DBConnector
from brtdevkit.core.db.db_filters import *  # We need this for pre-defined filters, e.g., ProjectFilter, DatetimeFilter

# Dl-Core
import brtdl.metrics
from brtdl.metrics.evaluate_metrics import evaluate_metrics
from brtdl.data import canonical_types, SampleKeys
from brtdl.data.dataset import PathDataset, PredictionDataset
from brtdl.data.loaders import im_reader, npy_reader
from brtdl import inference, default_arguments
from brtdl.transforms import SegmentationTransform
from brtdl.visual import colorize_segmentation

from divyas_utils import *

%matplotlib inline
logging.basicConfig(level=logging.INFO)

In [4]:
# Query the image database for the documents to evaluate

db = DBConnector()

# Collection window passed to the DatetimeFilter below
start_date = datetime(2020, 3, 1)
end_date = datetime(2020, 9, 1)

# Crop name filter, if needed
# NOTE(review): this variable is not referenced by the query below, which
# filters on 'COTTON' directly -- confirm which crop is intended.
crop_name = 'SOYBEANS'

img_filters = [
    {
        'project_name': 'shasta',
        "crop_name": 'COTTON',
        # At least one artifact that carries an S3 location
        'artifacts': {
            "$elemMatch": {
                's3_bucket': {"$exists": True},
                's3_key': {"$exists": True},
            },
        },
        # At least one annotation satisfying every condition listed here
        "annotations": {
            "$elemMatch": {
                "is_active_version": True,
                "state": "ok",
                "kind": {"$nin": ["ndvi_mask", "machine"]},
                "style": 'pixelwise',
                's3_bucket': {"$exists": True},
                's3_key': {"$exists": True},
                'label_map': {
                    '$in': [
                        {'1': 'crop', '2': 'weed'},
                        {'1': 'weed', '2': 'crop'},
                    ],
                },
            },
        },
    },
    DatetimeFilter(key='collected_on', start=start_date, end=end_date),
]
df = db.get_documents_df('image', img_filters, limit=None)
print(len(df))

58155


In [7]:
# Keep only the queried images whose _id appears in the CSV's nrg_id column
air = pd.read_csv('ai_rev_cotton_four_images.csv')
adf = df.loc[df['_id'].isin(air['nrg_id'])]

print("I have a dataframe of length: ", len(adf))


I have a dataframe of length:  296


In [None]:
# Bar chart of fscore per operating field, one bar per model.
# NOTE: `means`, `pal` and `sns` are defined in a later cell of this notebook,
# so that cell has to be run before this one.
fig, ax = plt.subplots(figsize=(12, 8))
sns.barplot(data=means, x="operating_field_name", y="fscore", hue="model", palette=pal, ax=ax)
ax.set_title('Mean Gridmetric Instance F1 Score')
fig.savefig('br_by_field_fscore.png')
plt.show()

In [1]:
# Functions to download inputs and run inference
          
def download_assets(directory, df):
    """
    Downloads the NRG artifact and one annotation for every row of a dataframe.

    For each row two files are written into `directory`:
    `<_id>_nrg.png` (the artifact whose kind is 'nrg') and
    `<_id>_ann.png` (the selected annotation).

    Args:
        directory (Path): the path that the assets will be saved to. Must support
            the `/` operator; parent directories are created as needed.
        df (pd.DataFrame): one row per image with the columns `_id`,
            `annotations` and `artifacts`. Every annotation/artifact used here
            must carry `s3_bucket` and `s3_key`.

    Raises:
        IndexError: if a row has no 'nrg' artifact, or no annotation matches
            either the primary or the fallback selection.
    """
    s3_client = brt_s3.S3()

    def _download(asset, filepath):
        # Make sure the destination folder exists before S3 writes into it
        filepath.parent.mkdir(parents=True, exist_ok=True)
        return s3_client.download_file(asset['s3_bucket'], asset['s3_key'], str(filepath))

    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        datapoint_id = str(row['_id'])
        try:
            # Preferred: the active pixelwise annotation that is not an NDVI mask
            annotation = [a for a in row['annotations'] if a['is_active_version'] and a['kind'] != "ndvi_mask" and a['style'] == 'pixelwise'][0]
        except Exception:
            # Best-effort fallback when the preferred selection finds nothing or a
            # key is missing. `Exception` (rather than a bare except) lets
            # KeyboardInterrupt / SystemExit stop a long download loop.
            annotation = [a for a in row['annotations'] if a['kind'] in ['f8', 'labelbox', 'brt', 'dataloop']][0]

        artifact = [a for a in row['artifacts'] if a['kind'] == 'nrg'][0]
        _download(artifact, directory / f'{datapoint_id}_nrg.png')
        _download(annotation, directory / f'{datapoint_id}_ann.png')
      

def download_images(df, artifact_kind, outputdir):
    """
    Fetch one artifact of the requested kind per dataframe row from S3.

    :param df: dataframe, expected to be created by a query to brtdevkit; every
               row exposes an `artifacts` list of dicts
    :param artifact_kind: string compared against each artifact's 'kind'
    :param outputdir: output directory; each file is named after the basename
                      of its S3 key
    """
    s3_client = brt_s3.S3()
    print(f'Downloading {len(df)} source images...')
    for _, record in df.iterrows():
        matching = [artifact for artifact in record.artifacts if artifact['kind'] == artifact_kind]
        if not matching:
            # Rows without an artifact of this kind are skipped silently
            continue
        chosen = matching[0]
        target = os.path.join(outputdir, os.path.basename(chosen['s3_key']))
        s3_client.download_file(chosen['s3_bucket'], chosen['s3_key'], target)
    print('Source image download complete')
    
def download_from_url(url, fname):
    """
    Helper function to download data given an URL
    :param url: string, url of the file to be downloaded
    :param fname: filename of the downloaded file
    :raises requests.HTTPError: if the server answers with a 4xx/5xx status;
        nothing is written to `fname` in that case
    """
    # stream=True avoids holding the whole payload (e.g. a model file) in memory.
    # NOTE(review): redirects are not followed (allow_redirects=False), so a 3xx
    # answer is saved as-is -- confirm this is intended.
    with requests.get(url, allow_redirects=False, stream=True) as r:
        # Fail loudly instead of saving an error page under the target name
        r.raise_for_status()
        with open(fname, 'wb') as fd:
            for chunk in r.iter_content(chunk_size=8192):
                fd.write(chunk)
            
def run_inference_local(img_dir, model_path, result_dir):
    """
    Run a locally available model over a directory of images and save one
    colorized prediction overlay per image.

    Use this function if the images and model are already downloaded and locally available.

    Inputs:
    img_dir: a directory with png images (should be NRG888) to analyze
    model_path: the location of the model you want to use
    result_dir: a directory to store the finished result images. Does not need to exist already;
                missing parent directories are created too.

    Outputs:
    None. For every prediction a file `<img_id>.jpg` is written to result_dir.
    """
    # makedirs(exist_ok=True) covers both an existing directory and nested paths
    os.makedirs(result_dir, exist_ok=True)

    channel_order = 'HWC'
    batch_size = 1

    # Fixed network input size; originally meant to come from the model metadata:
    # (model_metadata['input_shape'][2], model_metadata['input_shape'][3])
    resize_shape = (540, 960)
    with inference.managed_indir(default_arguments.PathTypes.abspath(img_dir)) as data_dir:
        suffix = inference.infer_suffix(data_dir, '')
        dset = inference.dset_dispatcher[suffix](data_dir, suffix, resize_shape, channel_order)

    # Run inference on downloaded images
    # NOTE(review): dset is consumed after the managed_indir context has exited --
    # confirm the context manager does not clean up data the dataset still needs.
    pred_gen = inference.predict_jit(dset, default_arguments.PathTypes.abspath(model_path), batch_size, device='cpu')
    for img_id, f_hat in pred_gen:
        # Class index with the highest score along axis 0
        y_hat = np.argmax(f_hat, axis=0).astype(np.ubyte)
        # Scaling by 255 assumes imread returns floats in [0, 1] (as matplotlib does for PNG)
        img = mpimg.imread(os.path.join(img_dir, img_id + suffix)) * 255.0
        img = img.astype(np.uint8)
        h, w = img.shape[:2]
        # order=0 (nearest neighbour) keeps the resized mask made of valid class ids
        resized_yhat = skimage.transform.resize(y_hat, (h, w), order=0,
                                                mode="constant", anti_aliasing=False,
                                                preserve_range=True)
        resized_yhat = resized_yhat.astype(np.uint8)
        colorized_img = colorize_segmentation(img, resized_yhat)
        out_path = os.path.join(result_dir, img_id + ".jpg")
        skimage.io.imsave(out_path, colorized_img)
        
def run_inference_from_web(image_initial, model_names, show_model_metadata=False):
    """
    Download the NRG images of a dataframe, fetch each requested model and run
    inference with it.

    Relies on two notebook-level names defined outside this function:
    `analysis_name` (prefix of the output directory) and `soy_models`
    (dict mapping a model name to an artifactory URL or a local model path).

    Inputs:
    image_initial: a pandas Dataframe queried from brtdevkit (where you can access the image s3 keys)
    model_names: a list containing model names from the soy_models dictionary
    show_model_metadata: if True, download and print the metadata.json stored next to each model

    Outputs, contained in separate directories under `<analysis_name>_<YYYY_MM_DD>`:
    Downloaded NRG Images
    Model downloaded from local file or artifactory url
    Results images generated for each model

    Note: an existing output directory with the same name is deleted first.
    """

    # Throw an error if the input DF is empty
    assert len(image_initial) > 0, 'Input DataFrame is empty. Nothing to analyze.'

    # Throw an error if more than one crop is represented in the DataFrame
    assert len(image_initial.crop_name.unique())==1, 'Multiple crop_names in DataFrame'

    # Built from `datetime`, which the notebook imports explicitly; `date` is not
    # among its explicit imports. Yields the same text as
    # str(date.today()).replace("-", "_").
    today = datetime.today().strftime('%Y_%m_%d')
    out_name = f'{analysis_name}_{today}'
    out_dir = Path(out_name)
    if out_dir.exists():
        # Start from a clean directory; results of an earlier run are discarded
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=False, exist_ok=True)
    # Create directories for base images to analyze
    img_dir = out_dir / 'nrg_images'
    img_dir.mkdir(parents=False, exist_ok=True)

    # Download Images To Infer
    download_images(image_initial, 'nrg', img_dir)

    for m in model_names:
        model_url = soy_models[m]
        result_dir = out_dir / (m + '_result_images')
        result_dir.mkdir(parents=False, exist_ok=True)

        if show_model_metadata:
            # metadata.json is expected next to the model file; the local copy
            # lives in out_dir and is overwritten for each model
            metadata_url = os.path.join(os.path.dirname(model_url), 'metadata.json')
            metadata_path = out_dir / 'metadata.json'
            download_from_url(metadata_url, metadata_path)
            with open(metadata_path, 'r') as f:
                model_metadata = json.load(f)
            print(m+' metadata.json:')
            print(model_metadata)
            print(' ')

        # TODO: Will need to change this to be available for multiple crop_names, not just soy
        # Download model
        if str(model_url).startswith('http'):
            model_path = out_dir / (m + '.jit')
            download_from_url(model_url, model_path)
        else:
            # Anything that is not a URL is treated as a local model path
            model_path = model_url
        run_inference_local(img_dir, model_path, result_dir)
        print(f'Inference completed for {m}')

def generate_metrics_from_labeled_images(df, model_names):
    """
    Takes a df of annotated images and generates metrics (fscore, precision, recall, etc)

    NOTE: work in progress. Only the output directory layout is created so far;
    downloading, metric generation and formatting are still to be written, so
    the function currently always ends with NotImplementedError.

    Inputs:
    df: DataFrame of images with annotations.
    model_names: Models to use on the evaluation - can be one model or a list of models.
                 Not used yet: the model is currently hard-coded to 'allSoy_0'.

    Outputs:
    metrics_df - a well-formatted DataFrame of model evaluation results (once implemented)

    Raises:
    NotImplementedError: always, after the directories have been created.
    """
    # Set up directories and filepaths for model, images
    # i can probably  reuse my code from the multi-model inference functions
    model_name = 'allSoy_0'
    # %S is seconds; the lowercase %s used before is not a portable strftime directive
    run_stamp = datetime.today().strftime("%Y%m%d_%H%M%S")
    out_dir = Path(run_stamp)  # This is where all your images/labels will be stored
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=False, exist_ok=True)

    model_dir = out_dir / 'ShastaModel'
    model_dir.mkdir(parents=False, exist_ok=True)

    img_dir = out_dir / 'Images'
    img_dir.mkdir(parents=False, exist_ok=True)

    preds_dir = out_dir / 'Preds'
    preds_dir.mkdir(parents=False, exist_ok=True)

    # Kept for the steps below, which are not written yet
    model_url = f'https://artifactory.bluerivertech.com/artifactory/dev-shasta-models/jit/{model_name}/{model_name}.jit'

    # Download Images and Annotations

    # Generate Metrics

    # Format DataFrame for plotting

    # Fail explicitly instead of returning a name that was never assigned
    raise NotImplementedError(
        'generate_metrics_from_labeled_images: asset download, metric generation '
        'and DataFrame formatting are not implemented yet'
    )




# Set up directories and filepaths for model, images


In [None]:
import seaborn as sns

# Average the numeric metrics per (field, model) pair.
# numeric_only=True keeps non-numeric columns out of the mean; recent pandas
# versions raise a TypeError on them instead of silently dropping them.
# NOTE: `br` is expected to be defined elsewhere in the notebook.
means = br.groupby(['operating_field_name', 'model']).mean(numeric_only=True).reset_index()

pal = ['skyblue', 'olive']

# Distribution of the per-field mean fscore, one box per model
fig, ax = plt.subplots(figsize=(4, 8))
sns.boxplot(data=means, x='model', y='fscore', palette=pal, ax=ax)
ax.set_title('GridMetric Instance F1 Score')
fig.savefig('br_allfields_fscore.png')
plt.show()

In [None]:
def generate_metrics_plots(metrics_df, palette, x, y, hue=None):
    """
    Requires seaborn, matplotlib
    Take a dataframe of model evaluation metrics and generate plots

    Inputs:
    metrics_df, a Dataframe of model evaluation metrics
    palette, a palette of colors to use in the plot
    x - The variable to display on the x axis
    y - The variable to display on the y axis
    hue - If desired, a variable to plot a comparison with

    Outputs:
    boxplot - the matplotlib Axes holding a boxplot of the x- and y-axis variables
    barplot - the matplotlib Axes holding a barplot of the x- and y-axis variables

    Raises:
    ValueError - if x, y or (when given) hue is not a column of metrics_df
    """

    # Assert that dataframe is formatted properly
    requested = [column for column in (x, y, hue) if column is not None]
    missing = [column for column in requested if column not in metrics_df.columns]
    if missing:
        raise ValueError(f'metrics_df is missing required column(s): {missing}')

    # generate boxplot and barplot, each on its own figure so neither overwrites the other
    _, box_ax = plt.subplots(figsize=(8, 6))
    boxplot = sns.boxplot(data=metrics_df, x=x, y=y, hue=hue, palette=palette, ax=box_ax)
    boxplot.set_title(f'{y} by {x}')

    _, bar_ax = plt.subplots(figsize=(8, 6))
    barplot = sns.barplot(data=metrics_df, x=x, y=y, hue=hue, palette=palette, ax=bar_ax)
    barplot.set_title(f'Mean {y} by {x}')

    return boxplot, barplot