# Импорты + Подготовка Данных

In [1]:
import jax
from jax import numpy as jnp
import optax
import matplotlib.pyplot as plt
import ces
import balance
from dataclasses import dataclass

In [2]:
@dataclass
class YearStat:
    year: int #year's number
    Z: jnp.array #IO tables
    s: jnp.array #primal resorce's prices index
    Y: jnp.array #final consumption
    S: jnp.array #production capacities

In [3]:
tables = balance.read_NIOT('NIOTS/RUS_NIOT_nov16.xlsx')
vvp8 = jnp.array([24.8, 26.062, 27.312, 29.304, 31.407, 33.410, 36.135, 39.219, 
        41.277, 38.049, 39.762, 41.458, 62.486, 63.602, 64.072]) #ВВП 0-14 года в ценах 8ого
curs_abs = jnp.array([28.13, 29.18, 31.36, 30.67, 28.81, 28.31, 27.14, 25.55, 
            24.87, 31.77, 30.38, 29.39, 31.08, 31.90, 38.60]) #курс доллара
zp_rel = jnp.array([1.209, 1.199, 1.162, 1.109, 1.106, 1.126, 1.133, 1.172, 
          1.115, 0.965, 1.052, 1.028, 1.084, 1.048, 1.012]) #реальная зп в процентах от предыдущей
zp_abs = jnp.array([jnp.exp(jnp.sum(jnp.log(zp_rel[:year+1]))) for year in range(15)])

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [4]:
curs = [x/curs_abs[8] for x in curs_abs] 
vvp = [x/vvp8[8] for x in vvp8] #или наоборот нужно?
zp = [x/zp_abs[8] for x in zp_abs]

In [5]:
n, m = 33, 3
years = range(15)
Zs = [tables[year][:,:n] for year in years]
Ys = [jnp.sum(tables[year][:n,n:], axis = 1) for year in years]
DATA = [YearStat(Z = Zs[year], Y = Ys[year], S = jnp.sum(Zs[year], axis = 0), 
                 year = year, s = jnp.array([curs[year], zp[year], vvp[year]])) for year in years]

# Модель

In [12]:
class BoundedModule():
    def __init__(self, Z, rho):
        n, m = Z.shape
        self.n, self.m = m, n - m
        self.rho = rho
        self.W = ces.weights(Z, rho)
    
    def forward(self, lam, Y, S, s):
        p = ces.balance_prices(self.W, self.rho, s, lam)
        f = jnp.dot(Y,p[:n]) - jnp.dot(lam, S)
        return f
    
    def __call__(self, Y, S, s):
        key = jax.random.PRNGKey(10)
        lam = jax.random.uniform(key, (self.n,), minval=0.0, maxval=2.0)
        q0 = jnp.dot(Y,ces.balance_prices(self.W, self.rho, s, jnp.zeros((self.n,)))[:n])
        learning_rate = 10e-2
        
        optimizer = optax.adam(learning_rate)
        opt_state = optimizer.init(lam)
        for i in range(10000):
            f, grads = jax.value_and_grad(fun = self.forward, argnums=0)(lam, Y, S, s)
            norm = jnp.linalg.norm(grads * (lam > jnp.zeros((self.n,))))
            updates, opt_state = optimizer.update(-1 * grads, opt_state)
            lam = jnp.maximum(optax.apply_updates(lam, updates), jnp.zeros((self.n,)))
            if i % 50 == 0:
                q_act = f + jnp.dot(lam, S)
                print(f'step {i}, gradient norm:{jnp.around(a=norm, decimals=2)}, implied inflation:+{jnp.round(a = 100*(q_act / q0 - 1), decimals=-0)}%')
            if norm <= 1e-2: #нужны нормальные правила остановки
                q_act = f + jnp.dot(lam, S)
                imp_inflation = 100*(q_act / q0 - 1)
                print(f'step {i}, gradient norm:{jnp.around(norm, 2)}, implied inflation:+{imp_inflation:.2f}%')
                break
        return lam

In [13]:
key = jax.random.PRNGKey(0)
Z, rho = DATA[8].Z, jax.random.uniform(key, (n,), minval= -1.0, maxval=10)
module = BoundedModule(Z, rho)

In [16]:
data = DATA[9]
Y, S, s = data.Y, data.S, data.s
lam = module(Y, S, s)
print(f'lambda = {jnp.around(lam, 2)}')

step 0, gradient norm:67352.8359375, implied inflation:+199.0%
step 50, gradient norm:636.0899658203125, implied inflation:+56.0%
step 100, gradient norm:10.210000038146973, implied inflation:+57.0%
step 150, gradient norm:1.5999999046325684, implied inflation:+57.0%
step 200, gradient norm:0.25, implied inflation:+57.0%
step 234, gradient norm:0.009999999776482582, implied inflation:+56.58%
lambda = [0.         0.         0.         0.         0.64       0.
 0.         0.         0.         2.47       0.65999997 0.26
 1.17       0.         0.         0.         2.97       1.04
 0.         0.         0.         0.         0.         0.
 0.06       0.         0.         0.         0.         0.
 0.         0.         0.        ]
