In [1]:
import numpy as np
import pandas as pd
import gcamwrapper as gw
import jax.numpy as jnp
from jax.numpy.linalg import norm as jnorm
import time

In [2]:
class BEVDeploymentTuning:
    """Tuning problem: adjust transportation share-weights so GCAM service output meets targets.

    The target CSV must provide the columns ``region``, ``sector``, ``subsector``,
    ``technology``, ``year`` and ``target``. The tuning parameters are the
    ``real-share-weight`` values of the target rows. Deviations are relative errors
    ``(target - physical-output) / target``, one per target row.
    """

    def __init__(self, target_data_fn, year_limit=None, start_year=2025):
        """Load the targets and set up the GCAM queries.

        Parameters
        ----------
        target_data_fn : str or path-like
            CSV file with the target rows (see class docstring for columns).
        year_limit : int, optional
            Last target year kept (inclusive). Defaults to 2100.
        start_year : int, optional
            First target year kept (inclusive). Defaults to 2025.
        """
        if year_limit is None:
            year_limit = 2100
        # Output compared against the targets.
        self.service_query = gw.get_query("transportation", "service")
        # Input varied by the solver: share-weights of every technology in sectors named trn_*.
        self.sw_query = 'world/region{region@name}/sector[+NamedFilter,StringRegexMatches,^trn_]/subsector{subsector@name}/technology{tech@name}/period{year@year}/real-share-weight'
        target_data = pd.read_csv(target_data_fn)
        # Inclusive on both ends, same as the original >= / <= pair.
        in_window = target_data['year'].between(start_year, year_limit)
        self.target_data = target_data[in_window]

    def initialize(self, g):
        """Snapshot the current share-weights for the target rows (inner join on the key columns)."""
        trn_sw = g.get_data(self.sw_query)
        keys = ['region', 'sector', 'subsector', 'technology', 'year']
        self.base_sw = self.target_data[keys].copy().merge(
            trn_sw.rename(columns={"period": "year"}), on=keys)

    def get_initial_tuning_params(self):
        """Return a copy of the snapshot share-weights (requires ``initialize``)."""
        return self.base_sw['real-share-weight'].copy()

    def get_num_tuning_output(self):
        """Return the number of target rows, i.e. the expected residual length."""
        return len(self.target_data['target'])

    def set_tuning_params(self, g, new_values):
        """Write ``new_values`` into GCAM as the share-weights of the snapshot rows.

        ``new_values`` is assigned to the ``real-share-weight`` column, so it must have
        one entry per row of ``base_sw`` (a pandas Series is aligned on its index).
        """
        new_sw = self.base_sw.copy()
        new_sw['real-share-weight'] = new_values
        g.set_data(new_sw, self.sw_query)

    def get_tuning_deviation(self, g):
        """Return relative errors ``(target - physical-output) / target`` per target row.

        Raises
        ------
        ValueError
            If the service output does not match the target rows one-to-one. Missing
            rows would otherwise be silently dropped by the inner merge, and duplicate
            keys would add rows. Either way the residual length would disagree with
            ``get_num_tuning_output()``.
        """
        keys = ['region', 'sector', 'subsector', 'technology', 'year']
        new_service = g.get_data(self.service_query)
        service_compare = self.target_data[keys + ['target']].copy().merge(new_service, on=keys)
        expected = self.get_num_tuning_output()
        if len(service_compare) != expected:
            raise ValueError(
                f"Service output matched {len(service_compare)} rows, expected {expected} target rows")
        service_compare['error'] = (service_compare.target - service_compare['physical-output']) / service_compare.target
        return service_compare['error']

In [3]:
def backtrack(F, tuner, g, x, x_new, fx, fx_new, fx_norm, fx_norm_new, max_iter=5):
    """Shorten a rejected solver step until the scaled residual norm drops below ``fx_norm``.

    The full step ``x -> x_new`` has already been evaluated (``fx_new``,
    ``fx_norm_new``). If it does not beat ``fx_norm``, trial points
    ``x + s * (x_new - x)`` are evaluated for s = 0.25, 0.125, ... (negative entries
    clipped to 0), up to ``max_iter`` trials. The norm used throughout is
    ``||F||_2 / len(F)``.

    Parameters
    ----------
    F : callable
        Residual function ``F(x, tuner, g)`` returning a 1-D array.
    tuner, g :
        Passed through to ``F`` unchanged.
    x, fx, fx_norm :
        Last accepted point, its residual (unused, kept for call compatibility)
        and its scaled norm.
    x_new, fx_new, fx_norm_new :
        The rejected full-step point, its residual and its scaled norm.
    max_iter : int
        Maximum number of shortened trials to evaluate (0 means none).

    Returns
    -------
    tuple
        ``(point, residual, norm)``. This is the first trial whose norm is below
        ``fx_norm``. If the norm rises between two successive trials, it is the
        previous (better) trial. Otherwise it is the last point evaluated.
    """
    print("backtracking")
    dx = jnp.array(x_new) - jnp.array(x)
    step_len = 0.5
    # Most recent candidate; starts at the rejected full step.
    x_curr, fx_curr, fx_norm_curr = x_new, fx_new, fx_norm_new
    print(f"fx_norm_curr: {fx_norm_curr}, fx_norm: {fx_norm}")
    if fx_norm_curr < fx_norm:
        return x_curr, fx_curr, fx_norm_curr
    # Norm of the previous *shortened* trial, used to detect a worsening direction.
    fx_norm_prev = None
    for _ in range(max_iter):
        step_len = step_len / 2.0
        x_trial = jnp.array(jnp.array(x) + step_len * dx)
        x_trial = jnp.where(x_trial < 0, 0.0, x_trial)
        fx_trial = F(np.asarray(x_trial), tuner, g)
        fx_norm_trial = jnorm(fx_trial, ord=2) / len(fx_trial)
        print(f"fx_norm_curr: {fx_norm_trial}, fx_norm: {fx_norm}")
        # Test every trial immediately so no (expensive) model evaluation is wasted.
        if fx_norm_trial < fx_norm:
            return x_trial, fx_trial, fx_norm_trial
        if fx_norm_prev is not None and fx_norm_trial > fx_norm_prev:
            print("Backtrack failed, likely wrong direction, returning previous value")
            return x_curr, fx_curr, fx_norm_curr
        x_curr, fx_curr, fx_norm_curr = x_trial, fx_trial, fx_norm_trial
        fx_norm_prev = fx_norm_trial
    print("Max backtrack iterations reached, returning last trial value")
    return x_curr, fx_curr, fx_norm_curr

In [4]:
def broyden(F, x0, tuner, g, J=None, tol=0.1, max_iter=100):
    """Solve ``F(x, tuner, g) = 0`` with Broyden's ("good") quasi-Newton method.

    Iterates are clipped to be non-negative. After the first iteration, a step that
    increases the scaled residual norm ``||F||_2 / len(F)`` is handed to
    ``backtrack``.

    Parameters
    ----------
    F : callable
        Residual function ``F(x, tuner, g)`` returning a 1-D array.
    x0 : array-like
        Initial guess (e.g. a pandas Series or numpy array).
    tuner, g :
        Passed through to ``F`` unchanged.
    J : array, optional
        Initial Jacobian approximation. Defaults to ``-I``.
    tol : float
        Convergence threshold on the scaled residual norm.
    max_iter : int
        Maximum number of iterations.

    Returns
    -------
    jax.Array
        The solution estimate.

    Raises
    ------
    RuntimeError
        If the norm does not fall below ``tol`` within ``max_iter`` iterations.
    """
    # Work on plain arrays so arithmetic never depends on pandas/JAX operator interop.
    x0_arr = np.asarray(x0, dtype=np.float64)
    x = jnp.asarray(x0_arr)
    if J is None:
        J = -jnp.eye(len(x))
    J_inv = jnp.linalg.inv(J)
    fx = F(x0_arr, tuner, g)
    old_norm = jnorm(fx, ord=2) / len(fx)
    all_norms = [old_norm]
    for i in range(max_iter):
        x_new = jnp.asarray(x - J_inv @ fx)
        x_new = jnp.where(x_new < 0, 0.0, x_new)
        fx_new = F(np.asarray(x_new), tuner, g)
        norm = jnorm(fx_new, ord=2) / len(fx_new)
        if i > 0 and norm > old_norm:
            x_new, fx_new, norm = backtrack(F, tuner, g, x, x_new, fx, fx_new, old_norm, norm)
        all_norms.append(norm)
        if norm < tol:
            print(all_norms)
            for value in x_new:
                print(value)
            return x_new
        xstep = jnp.asarray(x_new - x)
        fxstep = jnp.asarray(fx_new - fx)
        dx2 = jnp.dot(xstep, xstep)
        if dx2 > 0:
            # Broyden rank-1 update: J += ((dF - J dx) dx^T) / (dx^T dx).
            # This must be an outer product; an elementwise product of the 1-D vectors
            # would broadcast a single vector across every row of J.
            J = J + jnp.outer(fxstep - J @ xstep, xstep) / dx2
            J_inv = jnp.linalg.inv(J)
        # else: zero step (e.g. everything clipped) -- keep J to avoid dividing by zero.
        x = x_new
        fx = fx_new
        old_norm = norm
    print(all_norms)
    raise RuntimeError(f"Did not converge within {max_iter} iterations")

In [5]:
def F(x, tuner, g, run_years=(2025, 2040)):
    """Residual function for the solver: apply ``x``, rerun GCAM, return target deviations.

    Parameters
    ----------
    x : numpy array or pandas Series
        Candidate tuning parameters. They are cast to float64 before being handed to
        ``tuner.set_tuning_params``.
    tuner :
        Object providing ``set_tuning_params(g, values)`` and ``get_tuning_deviation(g)``.
    g :
        GCAM handle providing ``run_period`` and ``convert_year_to_period``.
    run_years : iterable of int
        Model years passed, in order, to ``g.run_period`` after the parameters are set.
        Defaults to the years previously hard-coded here (2025, then 2040).

    Returns
    -------
    jax.Array
        The tuner's deviation values as a 1-D array.
    """
    tuner.set_tuning_params(g, x.astype(np.float64))
    for year in run_years:
        g.run_period(g.convert_year_to_period(year))
    deviation = tuner.get_tuning_deviation(g)
    return jnp.array(deviation.to_numpy())

In [None]:
# Driver: build the BEV tuning problem, solve it with Broyden, save the solution.
BEV_TARGET_CSV = "../data/bev_target_us.csv"
LAST_TARGET_YEAR = 2040
RNG_SEED = 1919

bev_target = BEVDeploymentTuning(BEV_TARGET_CSV, year_limit=LAST_TARGET_YEAR)
g = gw.Gcam("config_minimal.xml", "../data/")
g.run_period(g.convert_year_to_period(LAST_TARGET_YEAR))

bev_target.initialize(g)
initial_bev_sw = bev_target.get_initial_tuning_params()

# Scale each share-weight by a random factor in [0, 2) to give the solver a perturbed start.
rand_gen = np.random.default_rng(RNG_SEED)
unit_draws = rand_gen.random((len(initial_bev_sw)))
x_init = initial_bev_sw * unit_draws * 2.0

start_time = time.time()
ans = broyden(F, x_init, bev_target, g)
elapsed = time.time() - start_time
print("Broyden took ", elapsed, " to run")
print(ans)

# One solution value per line.
with open('x_new.txt', 'w') as out_file:
    out_file.writelines(f"{element}\n" for element in ans)

Running GCAM model code base version 7.1 revision gcam-v7.1

Configuration file:  config_minimal.xml
Parsing input files...
Parsing ./input/gcamdata/xml/no_climate_model.xml scenario component.
Parsing ./input/gcamdata/xml/socioeconomics_gSSP2.xml scenario component.
Parsing ./input/gcamdata/xml/transportation_UCD_CORE.xml scenario component.
Parsing ./input/gcamdata/debug_test.xml scenario component.
Parsing ./input/solution/cal_broyden_config.xml scenario component.
XML parsing complete.
Starting new scenario: Reference
SEVERE ERROR:renewable in USA is not related to any other activities.
Starting a model run. Running period 9
Model run beginning.
Period 0: 1975
Model solved with last period's prices.

Period 1: 1990
Model solved with last period's prices.

Period 2: 2005
Model solved with last period's prices.

Period 3: 2010
Model solved with last period's prices.

Period 4: 2015
Model solved with last period's prices.

Period 5: 2020
Model solved with last period's prices.

Period