In [None]:
import sys

In [None]:
# Show the Python interpreter backing this kernel (handy for checking which environment is active).
sys.executable

In [None]:
!{sys.executable} -m pip install numpy pandas geopandas rasterio ws3

In [None]:
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
import pickle

In [None]:
# Load the TSA 00 stand inventory shapefile (stands.shp inside the tsa00.shp directory) for a visual check.
gdf = gpd.read_file('dat/shp/tsa00.shp/stands.shp')

In [None]:
# Map the inventory coloured by the 'age' attribute; the trailing `;` hides the Text/Axes repr.
ax = gdf.plot(column='age', figsize=(10, 10), legend=True)
ax.set_title("Inventory stands coloured by 'age'");

In [None]:
import os
from os import listdir
from os.path import isfile, join, dirname
import sys

In [None]:
import ws3

In [None]:
# Display the ws3 module repr (shows which installed copy of ws3 was imported).
ws3

In [None]:
def rasterize_inventory(basenames, shp_path, tif_path, hdt_path, theme_cols, age_col, base_year,
                        age_divisor=1., round_age=1., cap_age=None, d=100., verbose=True):
    """Rasterize each basename's stand inventory and pickle the returned hdt map.

    For every basename ``bn`` this calls ``ws3.common.rasterize_stands`` on
    ``<shp_path(bn)>/stands.shp``, writing ``<tif_path(bn)>/inventory_init.tif``,
    and pickles the returned map to ``<hdt_path>/hdt_<bn>.pkl``.

    Args:
        basenames: iterable of inventory basenames (e.g. ``['tsa00']``).
        shp_path: callable ``bn -> directory`` that contains ``stands.shp``.
        tif_path: callable ``bn -> directory`` where ``inventory_init.tif`` is written.
        hdt_path: directory for the pickles -- a plain string, NOT a callable.
        theme_cols, age_col: attribute columns forwarded to ``rasterize_stands``.
        base_year: unused; kept for interface compatibility.
        age_divisor, round_age, cap_age, verbose, d: forwarded to ``rasterize_stands``
            (this notebook passes its ``pixel_width`` as ``d``).

    Returns:
        dict mapping each basename to the object returned by ``rasterize_stands``.
    """
    hdt = {}
    for bn in basenames:
        hdt[bn] = ws3.common.rasterize_stands(
            shp_path='%s/stands.shp' % shp_path(bn),
            tif_path='%s/inventory_init.tif' % tif_path(bn),
            theme_cols=theme_cols,
            age_col=age_col,
            age_divisor=age_divisor,
            round_age=round_age,
            cap_age=cap_age,
            verbose=verbose,
            d=d)
        # `with` guarantees the pickle is flushed and the file handle closed.
        with open('%s/hdt_%s.pkl' % (hdt_path, bn), 'wb') as f:
            pickle.dump(hdt[bn], f)
    return hdt


In [None]:
# --- Input locations ---------------------------------------------------------
# Per-basename path builders: each takes a basename (e.g. 'tsa00') so that more
# TSAs can be added by extending `basenames` only.
dat_path = 'dat'
basenames = ['tsa00']
shp_path = lambda bn: './dat/shp/%s.shp' % bn      # directory containing stands.shp
tif_path = lambda bn: './dat/tif/%s' % bn          # directory containing inventory_*.tif
# Pickle file for the bn hdt map.
# NOTE(review): rasterize_inventory expects hdt_path as a directory string, not this callable.
hdt_path = lambda bn: './dat/hdt/hdt_%s.pkl' % bn

# --- Inventory / rasterization settings ---------------------------------------
snk_epsg = 3005 # BC Albers
tolerance = 10.
prop_names = [u'THLB', u'AU', u'LdSpp', u'Age2015', u'Shape_Area']
# NOTE(review): 6 output props vs 5 source names above -- presumably theme0 is not
# read from the source attributes; confirm before relying on a positional pairing.
prop_types = [(u'theme0', 'str:10'),
              (u'theme1', 'str:1'),
              (u'theme2', 'str:5'), 
              (u'theme3', 'str:50'), 
              (u'age', 'int:5'), 
              (u'area', 'float:10.1')]
update_area_prop = 'area'
pixel_width = 100.
shp_name = 'stands'
age_col = 'age'
theme_cols = ['theme0', 'theme1', 'theme2', 'theme3']
compress = 'lzw'
dtype = rasterio.uint8
base_year = 2015

# --- Forest model settings -----------------------------------------------------
tvy_name = 'totvol'
horizon = 10
period_length = 10
max_age = 1000
# Harvest is operable for ages 40..999 on stands matching the mask below (theme1 == '1').
oe_harvest = '_age >= 40 and _age <= 999'
action_params = {'harvest':{'oe':oe_harvest,
                            'mask':('?', '1', '?', '?'),
                            'is_harvest':True,
                            'targetage':0}}
yields_x_unit = 'years'
yields_period_length = 1

In [None]:
from ws3.common import rasterize_stands

# Rasterize each TSA's stands.shp into <tif_path(bn)>/inventory_init.tif. The returned
# map (theme hash -> development-type key) is what bootstrap_areas uses to decode pixels.
hdt = {bn: rasterize_stands('%s/shp/%s.shp' % (dat_path, bn),
                            '%s/inventory_init.tif' % tif_path(bn),
                            theme_cols, age_col, d=pixel_width)
       for bn in basenames}

In [None]:
import pickle
# Pickle each basename's hdt map to ./dat/hdt/hdt_<bn>.pkl; `with` closes the file handle.
for bn in basenames:
    with open('./dat/hdt/hdt_%s.pkl' % bn, 'wb') as f:
        pickle.dump(hdt[bn], f)

In [None]:
def compile_basecodes(hdt, basenames, theme_cols):
    """Collect the distinct theme codes present across all basenames.

    Args:
        hdt: dict mapping basename -> {hash: dt_key}, where each dt_key is a
            sequence holding one code per theme.
        basenames: basenames whose hdt maps are pooled.
        theme_cols: theme column names; only ``len(theme_cols)`` is used.

    Returns:
        A list with one sorted list of distinct codes per theme. Sorting makes the
        result deterministic: the iteration order of a ``set`` of strings changes
        between interpreter runs because of hash randomization. Basenames with an
        empty hdt map simply contribute no codes.
    """
    ntheme = len(theme_cols)
    codes = [set() for _ in range(ntheme)]
    for bn in basenames:
        for dt_key in hdt[bn].values():
            for i in range(ntheme):
                codes[i].add(dt_key[i])
    return [sorted(c) for c in codes]

In [None]:
# Pool the distinct theme codes from every TSA's hdt map.
basecodes = compile_basecodes(hdt=hdt, basenames=basenames, theme_cols=theme_cols)

In [None]:
def bootstrap_themes(fm, theme_cols=('theme0', 'theme1', 'theme2', 'theme3'), 
                     basecodes=None, aggs=None, verbose=False):
    """Register one theme per column on ``fm`` and record the theme count.

    Args:
        fm: forest model exposing ``add_theme(name, basecodes=..., aggs=...)``.
        theme_cols: theme names, in theme-index order.
        basecodes: per-theme lists of basecodes; defaults to an empty list per theme.
        aggs: per-theme dicts of aggregates; defaults to an empty dict per theme.
        verbose: unused; kept for interface compatibility.
    """
    # Fresh containers per call: list/dict defaults would be shared by every call,
    # so anything that later mutated them would leak from one model into the next.
    if basecodes is None:
        basecodes = [[] for _ in theme_cols]
    if aggs is None:
        aggs = [{} for _ in theme_cols]
    for ti, t in enumerate(theme_cols):
        fm.add_theme(t, basecodes=basecodes[ti], aggs=aggs[ti])
    fm.nthemes = len(theme_cols)


def bootstrap_areas(fm, basenames, rst_path, yld_path, hdt, year=None, new_dts=True):
    """Load period-0 development-type areas from the inventory rasters.

    When ``year`` is falsy, ``<rst_path(bn)>/inventory_init.tif`` is first copied to
    ``inventory_<fm.base_year>.tif`` for every basename and ``year`` becomes
    ``fm.base_year``. All existing development types then have their areas reset,
    and each ``<rst_path(bn)>/inventory_<year>.tif`` is read: band 1 holds each
    pixel's theme hash, band 2 its age. Pixels are grouped by (hash, age) and the
    group's area (pixel count times pixel area) is added to the matching
    development type in period 0.

    Args:
        fm: forest model (uses ``base_year`` and ``dtypes``).
        basenames: basenames to load.
        rst_path: callable ``bn -> directory`` holding the inventory rasters.
        yld_path: unused; kept for interface compatibility.
        hdt: dict basename -> {theme hash: development-type key}.
        year: raster year to load (see above for the falsy case).
        new_dts: if true, (re)create a ``DevelopmentType`` for every key in ``hdt[bn]``.
    """
    import shutil
    print('bootstrap_areas', basenames)
    if not year:
        for bn in basenames:
            src_tif = '%s/inventory_init.tif' % rst_path(bn)
            dst_tif = '%s/inventory_%i.tif' % (rst_path(bn), fm.base_year)
            print('copying', src_tif, dst_tif)
            shutil.copyfile(src_tif, dst_tif)
        year = fm.base_year
    # NOTE(review): both reset_areas(0) and reset_areas() are called ("yuck" in the
    # original); the intent of the pair is unclear, so both are kept.
    for dt in fm.dtypes.values():
        dt.reset_areas(0)
        dt.reset_areas()
    for bn in basenames:
        _sumarea = 0.
        with rasterio.open('%s/inventory_%i.tif' % (rst_path(bn), year), 'r') as src:
            pxa = pow(src.transform.a, 2) * 0.0001 # pixel area (hectares)
            bh, ba = src.read(1), src.read(2)
            for h, dt in hdt[bn].items():
                ra = ba[bh == h] # ages of the pixels whose theme hash matches h
                if new_dts:
                    fm.dtypes[dt] = ws3.forest.DevelopmentType(dt, fm)
                # One counting pass instead of re-scanning `ra` once per distinct age.
                ages, counts = np.unique(ra, return_counts=True)
                for age, count in zip(ages, counts):
                    area = int(count) * pxa
                    _sumarea += area
                    fm.dtypes[dt].area(0, age, area)
        print('bootstrap_areas', bn, year, pxa, _sumarea)

        
def bootstrap_yields(fm, yld_path, spcode='canfi_species', 
                        x_max=350, period_length=10., tvy_name='totvol', x_unit='years'):
    """Register species volume curves, bird-AF curves and a total-volume curve on ``fm``.

    Reads from ``yld_path``:

    * ``au_table.csv`` (indexed by ``au_id``; uses ``unmanaged_curve_id`` and ``canfi_species``),
    * ``curve_table.csv`` (read but not used),
    * ``curve_points_table.csv`` (indexed by ``curve_id``; columns ``x`` and ``y``).

    For every AU whose theme-2 mask ``('?', '?', curve_id, '?')`` matches at least one
    development type:

    1. a volume curve named ``s%04d`` (from ``canfi_species``) is built from the points
       whose ``x`` is a multiple of ``period_length`` and ``<= x_max``;
    2. a ``birdaf`` curve ``exp(c_vol * volume + c_age * age + c_0)`` is derived from it,
       using the softwood coefficient for ``swvol_ynames`` species, hardwood otherwise.

    Finally ``tvy_name`` is registered on every development type as the ``_SUM`` of the
    volume curves only.

    Args:
        spcode, x_unit: unused; kept for interface compatibility.
    """
    import math
    au_table = pd.read_csv('%s/au_table.csv' % yld_path).set_index('au_id')
    curve_table = pd.read_csv('%s/curve_table.csv' % yld_path)  # NOTE(review): read but unused
    curve_points_table = pd.read_csv('%s/curve_points_table.csv' % yld_path).set_index('curve_id')

    # add constants (for bird AF coefficients)    
    c1 = fm.constants['birdaf_swvol_coeff'] =     +2.181e-3
    c2 = fm.constants['birdaf_hwvol_coeff'] =     -1.176e-2
    c3 = fm.constants['birdaf_age_coeff'] =       -8.235e-4
    c4 = fm.constants['birdaf_intercept_coeff'] = +7.594e-1

    swvol_ynames = ['s0105', 's0204', 's0101', 's0304', 's0100', 's0104']
    hwvol_ynames = ['s1201', 's1211']  # only referenced by the disabled hwvol aggregate below
    vol_ynames = set()  # species volume curves registered in this call

    for au_id, au_row in au_table.iterrows():
        curve_id = au_row.unmanaged_curve_id # if not is_managed else au_row.managed_curve_id
        mask = ('?', '?', str(curve_id), '?')
        dt_keys = fm.unmask(mask)
        if not dt_keys: continue
        
        # add volume curve
        yname = 's%04d' % int(au_row.canfi_species)
        # A list label keeps a DataFrame even when curve_id has a single point row
        # (a scalar label would return a Series, which has no iterrows()).
        au_points = curve_points_table.loc[[curve_id]]
        points = [(r.x, r.y) for _, r in au_points.iterrows() if not r.x % period_length and r.x <= x_max]
        curve = fm.register_curve(ws3.core.Curve(yname, points=points, type='a', is_volume=True, xmax=fm.max_age, period_length=period_length))
        fm.yields.append((mask, 'a', [(yname, curve)]))
        fm.ynames.add(yname)
        vol_ynames.add(yname)
        for dtk in dt_keys: 
            fm.dtypes[dtk].add_ycomp('a', yname, curve)
        
        # add birdaf curve: exp(c_vol * volume + c_age * age + intercept)
        curve = curve * (c1 if yname in swvol_ynames else c2) + fm.common_curves['ages'] * c3 + c4
        points = [(x, math.exp(y)) for x, y in curve.points()]
        #####################################################################################################################
        # remove negative y values (else negative values will "cancel out" positive values when rolled up to landscape level)
        # TO DO: confirm that this is "the right thing to do" (i.e., consistent with statistical interpretation of bird AF)
        #points = [(x, max(0., y)) for x, y in curve.points()]
        #####################################################################################################################
        curve.add_points(points=points, compile_y=True) 
        yname = 'birdaf'
        fm.yields.append((mask, 'a', [(yname, curve)]))
        fm.ynames.add(yname)
        for dtk in dt_keys: 
            fm.dtypes[dtk].add_ycomp('a', yname, curve)
                
    mask = ('?', '?', '?', '?')
                
    # add total volume curve ###
    # Sum the volume curves only: fm.ynames also holds 'birdaf', which is derived from
    # volume but is not itself a volume. Sorted so the expression is stable run to run.
    expr = '_SUM(%s)' % ', '.join(sorted(vol_ynames))
    fm.yields.append((mask, 'c', [(tvy_name, expr)]))
    fm.ynames.add(tvy_name)
    for dtk in fm.dtypes.keys(): fm.dtypes[dtk].add_ycomp('c', tvy_name, expr)  
    
    ## add softwood volume curve
    #yname = 'swvol'
    #expr = '_SUM(%s)' % ', '.join(swvol_ynames)
    #fm.yields.append((mask, 'c', [(yname, expr)]))
    #fm.ynames.add(yname)
    #for dtk in fm.dtypes.keys(): fm.dtypes[dtk].add_ycomp('c', yname, expr)
    
    # add hardwood volume curve
    #yname = 'hwvol'
    #expr = '_SUM(%s)' % ', '.join(hwvol_ynames)
    #fm.yields.append((mask, 'c', [(yname, expr)]))
    #fm.ynames.add(yname)
    #for dtk in fm.dtypes.keys(): fm.dtypes[dtk].add_ycomp('c', yname, expr)  
    
def bootstrap_actions(fm, action_params):
    """Define each action in ``action_params`` on ``fm`` and wire up its transitions.

    ``action_params`` maps an action code to a dict with keys ``'mask'``, ``'oe'``
    (operability expression), ``'is_harvest'`` and ``'targetage'``. For each action
    the model-level action, operability expression and transition (one target that
    reuses ``mask`` with proportion 1.0) are registered. Every development type
    matched by ``mask`` receives the operability expression plus a transition entry
    for each age in ``1 .. fm.max_age - 1`` at which it is operable in period 1.
    """
    for acode, params in action_params.items():
        mask, oe, is_harvest, targetage = (params[k] for k in ('mask', 'oe', 'is_harvest', 'targetage'))
        target = [(mask, 1.0, None, None, None, None, None)]
        fm.actions[acode] = ws3.forest.Action(acode, targetage=targetage, is_harvest=is_harvest)
        fm.oper_expr[acode] = {mask: oe}
        fm.transitions[acode] = {mask: {'': target}}
        for dt_key in fm.unmask(mask):
            devtype = fm.dtypes[dt_key]
            devtype.oper_expr[acode] = [oe]
            operable_ages = (a for a in range(1, fm.max_age) if devtype.is_operable(acode, 1, a))
            for a in operable_ages:
                devtype.transitions[acode, a] = target

    
def bootstrap_forestmodel(basenames, model_name, model_path, base_year, yld_path, tif_path, horizon, 
                          period_length, max_age, basecodes, action_params, hdt,
                          add_null_action=True, tvy_name='totvol', compile_actions=True,
                          yields_x_unit='periods', yields_period_length=None, verbose=0):
    """Build and initialise a ws3 ``ForestModel`` from inventory rasters and yield tables.

    Steps:
    1. Create the model.
    2. Add themes (``basecodes``).
    3. Load areas from the rasters under ``tif_path(bn)``.
    4. Load yield curves from ``yld_path``.
    5. Define actions from ``action_params``, plus a null action if ``add_null_action``.
    6. Compile actions if ``compile_actions``.
    7. Reset actions, initialise areas and grow the model.

    Args:
        yields_period_length: period length handed to ``bootstrap_yields``; falls
            back to ``period_length`` when falsy.
        verbose: unused; kept for interface compatibility.

    Returns:
        The bootstrapped ``ForestModel``.
    """
    from ws3.forest import ForestModel
    if not yields_period_length:
        yields_period_length = period_length
    fm = ForestModel(model_name=model_name, 
                     model_path=model_path,
                     base_year=base_year,
                     horizon=horizon,     
                     period_length=period_length,
                     max_age=max_age)
    bootstrap_themes(fm, basecodes=basecodes)
    bootstrap_areas(fm, basenames, tif_path, yld_path, hdt)
    bootstrap_yields(fm, yld_path, tvy_name=tvy_name, period_length=yields_period_length, x_unit=yields_x_unit)
    bootstrap_actions(fm, action_params)
    if add_null_action:
        fm.add_null_action()
    if compile_actions:  # honour the caller's flag
        fm.compile_actions()
    fm.reset_actions()
    fm.initialize_areas()
    fm.grow()
    return fm

In [None]:
# Keyword arguments for bootstrap_forestmodel; numeric settings are cast to int explicitly.
kwargs = dict(basenames=basenames,
              model_name='foo',
              model_path=dat_path,
              base_year=int(base_year),
              yld_path=dat_path,
              tif_path=tif_path,
              horizon=int(horizon),
              period_length=int(period_length),
              max_age=int(max_age),
              basecodes=basecodes,
              action_params=action_params,
              hdt=hdt,
              add_null_action=True,
              tvy_name=tvy_name,
              compile_actions=True,
              yields_x_unit=yields_x_unit,
              yields_period_length=int(yields_period_length),
              verbose=1)

In [None]:
# Build and initialise the forest model (themes, areas, yields, actions) from the settings above.
fm = bootstrap_forestmodel(**kwargs)

In [None]:
def _gen_scen_base(fm, basenames, name='base', util=0.85, param_funcs=None, target_scalefactors=None, harvest_acode='harvest', fire_acode='fire', 
                   tvy_name='totvol', toffset=0, obj_mode='min_harea', target_path='./input/targets.csv',
                   max_tp=2020, cacut=None, mask=None):
    """Build the base-scenario optimisation problem and register it via ``fm.add_problem``.

    Objective: minimise harvested area (``obj_mode='min_harea'``, coefficient ``'1.'``)
    or maximise harvested volume (``obj_mode='max_hvol'``, ``'<tvy_name> * <util>'``).

    Per basename ``bn`` (theme 0 of each coefficient mask):

    * ``cacut_<bn>``: harvest-area coefficient, entered in ``cflw_e`` with per-period
      epsilons from ``param_funcs['cflw_acut_e']`` (presumably a flow constraint);
    * ``cvcut_<bn>``: harvest volume bounded to ``[vcut*L*(1-e), vcut*L]``, with ``L`` the
      period length and ``e`` from ``param_funcs['cgen_vcut_e']``;
    * ``cacut_<bn>`` general bounds when ``cacut`` is truthy (needs caller-supplied
      ``param_funcs['cacut']``; the default param funcs do not define it);
    * ``cabrn-thlb<i>_<bn>`` for ``i`` in ``'0'``/``'1'`` when ``fire_acode`` is set: burned
      area bounds scaled by the model-wide share of period-0 area with theme1 == i.

    Without ``param_funcs``, targets are read from ``target_path`` (indexed by
    ``tsa``, ``year``); years after ``max_tp`` reuse the ``max_tp`` row.

    Raises:
        ValueError: for an unknown ``obj_mode``.

    Returns:
        Whatever ``fm.add_problem`` returns.
    """
    fm.foo3 = target_scalefactors  # NOTE(review): debug hook; kept since callers may inspect it
    from functools import partial
    acodes = ['null', harvest_acode, fire_acode]  
    vexpr = '%s * %0.2f' % (tvy_name, util)
    if obj_mode == 'max_hvol':
        sense = ws3.opt.SENSE_MAXIMIZE 
        zexpr = vexpr
    elif obj_mode == 'min_harea':
        sense = ws3.opt.SENSE_MINIMIZE 
        zexpr = '1.'
    else:
        raise ValueError('Invalid obj_mode: %s' % obj_mode)
    if not param_funcs:
        target_scalefactors = {bn:1. for bn in basenames} if not target_scalefactors else target_scalefactors
        df_targets = pd.read_csv(target_path).set_index(['tsa', 'year'])
        # The area epsilon was previously read from 'cgen_vcut_e' (copy-paste slip);
        # tables without a 'cgen_acut_e' column keep that fallback.
        acut_e_col = 'cgen_acut_e' if 'cgen_acut_e' in df_targets.columns else 'cgen_vcut_e'

        def _target(bn, t, col):
            # Years after max_tp reuse the max_tp row of the targets table.
            return df_targets.loc[bn, min(t, max_tp)][col]

        param_funcs = {
            'cvcut': lambda bn, t: float(_target(bn, t, 'vcut')) * target_scalefactors[bn],
            'cabrn': lambda bn, t: float(_target(bn, t, 'abrn')),
            'cflw_acut_e': lambda bn, t: _target(bn, t, 'cflw_acut_e'),
            'cgen_vcut_e': lambda bn, t: _target(bn, t, 'cgen_vcut_e'),
            'cgen_acut_e': lambda bn, t: _target(bn, t, acut_e_col),
            'cgen_abrn_e': lambda bn, t: _target(bn, t, 'cgen_abrn_e'),
        }

    def _year(t, offset=0):
        # Calendar year associated with period t (period 1 -> fm.base_year).
        return fm.base_year + (t - 1) * fm.period_length + offset

    coeff_funcs = {'z':partial(cmp_c_z, expr=zexpr)}
    coeff_funcs.update({'cacut_%s' % bn:partial(cmp_c_caa, expr='1.', acodes=[harvest_acode], mask=(bn, '?', '?', '?')) 
                        for bn in basenames})
    coeff_funcs.update({'cvcut_%s' % bn:partial(cmp_c_caa, expr=vexpr, acodes=[harvest_acode], mask=(bn, '?', '?', '?')) 
                        for bn in basenames})
    if fire_acode:
        for i in ['0', '1']:
            coeff_funcs.update({'cabrn-thlb%s_%s' % (i, bn):partial(cmp_c_caa, expr='1.', acodes=[fire_acode], mask=(bn, i, '?', '?')) 
                                for bn in basenames})
    T = fm.periods
    cflw_e, cgen_data = {}, {}
    cflw_e.update({'cacut_%s' % bn: ({t: param_funcs['cflw_acut_e'](bn, _year(t, toffset)) for t in T}, fm.periods[-1])
                   for bn in basenames})
    for bn in basenames:
        cgen_data.update({'cvcut_%s' % bn: {
            'lb': {t: param_funcs['cvcut'](bn, _year(t, toffset)) * fm.period_length *
                      (1. - param_funcs['cgen_vcut_e'](bn, _year(t, toffset)))
                   for t in T},
            'ub': {t: param_funcs['cvcut'](bn, _year(t, toffset)) * fm.period_length for t in T}}})
        if cacut:
            # NOTE(review): unlike cvcut, the cacut and cabrn bounds ignore toffset.
            cgen_data.update({'cacut_%s' % bn: {
                'lb': {t: param_funcs['cacut'](bn, _year(t)) * fm.period_length *
                          (1. - param_funcs['cgen_acut_e'](bn, _year(t))) for t in T},
                'ub': {t: param_funcs['cacut'](bn, _year(t)) * fm.period_length for t in T}}})
        if fire_acode:
            for i in ['0', '1']:
                # NOTE(review): this share is computed over the whole model, not just bn.
                p = fm.inventory(0, mask='? %s ? ?' % i) / fm.inventory(0, mask='? ? ? ?')
                cgen_data.update({'cabrn-thlb%s_%s' % (i, bn): {
                    'lb': {t: param_funcs['cabrn'](bn, _year(t)) * p * fm.period_length *
                              (1. - param_funcs['cgen_abrn_e'](bn, _year(t))) for t in T},
                    'ub': {t: param_funcs['cabrn'](bn, _year(t)) * p * fm.period_length for t in T}}})
    # Expose intermediate inputs on fm for inspection.
    fm._tmp = {}
    fm._tmp['param_funcs'] = param_funcs
    fm._tmp['cgen_data'] = cgen_data
    return fm.add_problem(name, coeff_funcs, cflw_e, cgen_data=cgen_data, acodes=acodes, sense=sense, mask=mask)

In [None]:
def schedule_harvest_areacontrol(fm, period=1, acode='harvest', util=0.85, 
                                 target_masks=None, target_areas=None, target_scalefactors=None,
                                 mask_area_thresh=0.,
                                 verbose=0, tvy_name='totvol'):
    """Schedule ``acode`` in ``period`` with an area-control rule; return the compiled schedule.

    Starts with ``fm.reset_actions()``. NOTE(review): presumably this clears actions
    scheduled by earlier calls; confirm that before calling this once per period in a loop.

    If ``target_areas`` is not given, one target area is computed per mask in
    ``target_masks`` as ``(1 / r) * masked_area * scalefactor``. Here ``r`` is the
    area-weighted mean CMAI age, i.e. ``ycomp(tvy_name).mai().ytp().lookup(0)``,
    over the development types in the mask. Masks with zero period-0 inventory are
    dropped, so each remaining mask stays paired with its own area.

    If ``target_masks`` is not given either, it defaults to one ``'? 1 <au> ?'`` mask
    per theme-2 code (AU-wise THLB). AUs whose area is not above ``mask_area_thresh``
    are pooled into the ``'areacontrol_au_agg'`` aggregate of theme 2.

    Args:
        util: unused; kept for interface compatibility.
        target_scalefactors: optional per-mask multipliers, indexed by position in ``target_masks``.
        verbose: ``> 0`` prints progress; also forwarded to ``inventory`` and ``operate``.
        tvy_name: yield component used for the CMAI age (default ``'totvol'``).

    Returns:
        ``fm.compile_schedule()``.
    """
    fm.reset_actions()
    if not target_areas:
        if not target_masks: # default to AU-wise THLB 
            au_vals = []
            au_agg = []
            for au in fm.theme_basecodes(2):
                mask = '? 1 %s ?' % au
                masked_area = fm.inventory(0, mask=mask)
                if masked_area > mask_area_thresh:
                    au_vals.append(au)
                else:
                    au_agg.append(au)
                    if verbose > 0:
                        print('adding to au_agg', mask, masked_area)
            if au_agg:
                fm._themes[2]['areacontrol_au_agg'] = au_agg 
                au_vals.append('areacontrol_au_agg')
            target_masks = ['? 1 %s ?' % au for au in au_vals]
        kept_masks, target_areas = [], []
        for i, mask in enumerate(target_masks): # area-weighted mean CMAI age per masked DT set
            masked_area = fm.inventory(0, mask=mask, verbose=verbose)
            if not masked_area:
                continue  # drop the mask too, so masks and areas stay aligned below
            r = sum((fm.dtypes[dtk].ycomp(tvy_name).mai().ytp().lookup(0) * fm.dtypes[dtk].area(0)) for dtk in fm.unmask(mask))
            r /= masked_area
            asf = 1. if not target_scalefactors else target_scalefactors[i]  
            ta = (1/r) * masked_area * asf
            kept_masks.append(mask)
            target_areas.append(ta)
        target_masks = kept_masks
    for mask, target_area in zip(target_masks, target_areas):
        if verbose > 0:
            print('calling areaselector', period, acode, target_area, mask)
        fm.areaselector.operate(period, acode, target_area, mask=mask, verbose=verbose)
    sch = fm.compile_schedule()
    return sch

In [None]:
# Area-control harvest for period 1 only.
# NOTE(review): schedule_harvest_areacontrol begins with fm.reset_actions(), so a loop like
#     for p in fm.periods: schedule_harvest_areacontrol(fm, period=p, verbose=True)
# may discard earlier periods' actions on every call -- confirm reset_actions() semantics first.
sch = schedule_harvest_areacontrol(fm, period=1, verbose=True)

In [None]:
# compile_product(p, '1.') for every planning period -- presumably the area treated per period
# by the schedule above (TODO confirm which actions compile_product counts by default).
{p:fm.compile_product(p, '1.') for p in fm.periods}

In [None]:
# Model inventory in period 1, after the harvest scheduling above.
fm.inventory(1)

In [None]:
def sda(fm, basenames, time_step, tif_path, hdt, acode_map=None, nthresh=0, sda_mode='randpxl', verbose=False):
    """Spatially disaggregate the compiled schedule of ``fm`` onto the inventory rasters.

    For each basename ``bn`` a ``ws3.spatial.ForestRaster`` is built from
    ``<tif_path(bn)>/inventory_<fm.base_year>.tif`` and ``hdt[bn]``. Its outputs
    go under ``<parent of tif_path(bn)>/<bn>``, which is created if missing. The
    schedule is allocated with ``sda_mode``, and the raster is cleaned up even if
    allocation raises.

    Args:
        time_step: forwarded to ``ForestRaster``.
        acode_map: action code -> output layer name; defaults to
            ``{'harvest': 'projected_harvest'}``.
        nthresh: unused; kept for interface compatibility.
        sda_mode: allocation mode passed to ``allocate_schedule`` (e.g. ``'randpxl'``, ``'randblk'``).
        verbose: forwarded to ``allocate_schedule``.
    """
    from pathlib import Path
    from ws3.spatial import ForestRaster
    from ws3.common import hash_dt
    import os
    if acode_map is None:
        acode_map = {'harvest':'projected_harvest'}
    def cmp_fr_kwargs(bn):
        tmp_path = os.path.split(tif_path(bn))[0]
        _tif_path = '%s/%s' % (tmp_path, bn)
        # Create missing parents too; exist_ok avoids a check-then-create race.
        Path(_tif_path).mkdir(parents=True, exist_ok=True)
        fr_kwargs = {'hdt_map':hdt[bn], 
                     'hdt_func':hash_dt, 
                     'src_path':'%s/inventory_%i.tif' % (tif_path(bn), fm.base_year),
                     'snk_path':_tif_path,
                     'acode_map':acode_map,
                     'forestmodel':fm,
                     'horizon':fm.horizon,
                     'period_length':fm.period_length,
                     'time_step':time_step,
                     'base_year':fm.base_year,
                     'piggyback_acodes':{}}
        return fr_kwargs
    for bn in basenames:
        print('SDA for TSA', bn)
        mask = None  # whole-model schedule; a per-TSA mask would be (bn, '?', '?', '?')
        fr = ForestRaster(**cmp_fr_kwargs(bn))
        try:
            fr.allocate_schedule(mask=mask, verbose=verbose, sda_mode=sda_mode)
        finally:
            fr.cleanup()  # release raster resources even when allocation fails

In [None]:
# Output layer name for every non-null action code (e.g. 'harvest' -> 'projected_harvest').
acode_map = {code: 'projected_%s' % code for code in fm.actions if code != 'null'}

In [None]:
# Allocate the compiled schedule onto the rasters in random blocks, one-year time step.
sda(fm, basenames, time_step=1, tif_path=tif_path, hdt=hdt, acode_map=acode_map,
    sda_mode='randblk', verbose=3)