In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
# Run on the GPU when one is available; fall back to the CPU otherwise so the
# notebook still executes (more slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
from sklearn.preprocessing import StandardScaler
from scipy.interpolate import NearestNDInterpolator
from mgwr.kernels import Kernel

In [3]:
# Economic settings: cotton_price multiplies yield and nitrogen_cost
# multiplies the applied rate when the net return is computed.
cotton_price, nitrogen_cost = 0.4, 1.0
# Cost/price ratio entering the optimum-rate condition of the quadratic model.
nitrogen_ratio = nitrogen_cost / cotton_price
# Baseline ("sq", presumably status quo) nitrogen rate and yield.
sq_rate, sq_yield = 100, 4500
# Single quadratic (second-order) coefficient shared by every cell.
rst_b2g = -0.02

In [4]:
# Load the trial designs, their names, and one (x, y) coordinate row per cell
# of the flattened trial grid.
trial = np.load('../data/Trial_Design.npy')
# allow_pickle is required here (presumably the names are an object array).
# Unpickling can execute code, so only load this file from a trusted source.
trial_names = np.load('../data/Trial_Design_names.npy', allow_pickle=True)
trial_coords = np.load('../data/Trial_Design_coords.npy')
# Integer grid index of each coordinate (presumably 3 coordinate units per
# grid cell — TODO confirm). `np.int0` was removed in NumPy 2.0; `np.intp` is
# the type it aliased, so this is a like-for-like replacement.
trial_coords_idx = np.floor(trial_coords / 3).astype(np.intp)
n = trial_coords.shape[0]

# Simulated coefficient fields (split into intercept/slope parts downstream).
rst_sim = np.load('../data/Trial_sim.npy')

rst_sim = np.load('../data/Trial_sim.npy')

In [5]:
# Ground-truth response surfaces: the first 100 entries of rst_sim (axis 0)
# perturb the yield level, the remaining ones perturb the optimal rate.
b0_true = rst_sim[:100]
b1_true = rst_sim[100:]

# True optimal nitrogen rate per cell, clipped to the range [0, 200].
rst_optr = sq_rate * (1 + 0.25 * b1_true)
rst_optr = np.clip(rst_optr, 0, 200)

# Quadratic yield model  y = b0 + b1*N + b2g*N**2.
# Setting b1 = -2*b2g*optr + cost/price puts the maximum of the net return
# price*y - cost*N exactly at N = optr (b2g < 0, so it is a maximum).
rst_b1 = -2 * rst_b2g * rst_optr + nitrogen_ratio
# b0 is chosen so that the yield at the optimal rate equals the perturbed
# baseline yield sq_yield * (1 + 0.05 * b0_true).
rst_b0 = sq_yield * (1 + 0.05 * b0_true)
rst_b0 = rst_b0 - (rst_b1 * rst_optr + rst_b2g * rst_optr ** 2)

# Indices of the treatment designs; entries whose name contains 'rep' are the
# replicate layers and are skipped.
trial_idx = [i for i, name in enumerate(trial_names) if 'rep' not in name]

In [6]:
def _constrained_coefs(c):
    """Project raw coefficients onto the model's constraints.

    `c` is the 4-D coefficient tensor built in the loop below, laid out as
    (leading axis, polynomial order 0..2, replicate R, cell-within-replicate K);
    the leading axis presumably indexes the simulated fields — confirm.

    Every coefficient is replaced by its replicate mean (one coefficient per
    replicate). The second-order coefficient is further collapsed to a single
    scalar mean.

    NOTE(review): that scalar mean pools the leading axis as well as the
    replicates, so the simulated fields are not fitted independently for the
    quadratic term. Confirm this is intended.
    """
    cm = c.mean(-1, keepdim=True).repeat(1, 1, 1, c.shape[-1])
    cm[:, 2] = cm[:, 2].mean()
    return cm


for idx in tqdm(trial_idx):
    exp_name = trial_names[idx]

    # --- Replicate layout ---------------------------------------------------
    # Replicate layer of this design's group (presumably the 'rep' entry at
    # every 6th index — confirm). Cells with a NaN treatment get label 0.
    rep = trial[6 * (idx//6)].copy()
    rep[np.isnan(trial[idx])] = 0
    # One-hot encode the replicate labels; label 0 is always column 0.
    ridx = np.unique(np.insert(np.unique(rep), 0, 0))
    repdm = rep[:,:,None] == ridx[None, None]
    repf = repdm.argmax(-1).astype('float')
    # (row, col) of the cells of each replicate, stacked to shape (2, R, K).
    # np.stack requires every replicate to contain the same number of cells K.
    ridx = np.stack(np.where(repdm[:,:,1:]))
    ridx = np.stack([ridx[:2,ridx[2] == i] for i in np.unique(ridx[2])], 1)

    # Assign every grid cell, treated or not, to its nearest replicate.
    # The 0-based index follows the order of the R axis of `ridx`.
    irep = repf.reshape(-1)
    z = irep[irep > 0]
    xy = trial_coords[irep > 0]
    interp = NearestNDInterpolator(xy, z)
    repf = interp(trial_coords) - 1
    repf = repf.reshape(rep.shape).astype('int')

    # --- Simulated observations ---------------------------------------------
    # Applied rate: sq_rate plus 12.5 per treatment level (NaN -> level 0).
    trial_rst = np.nan_to_num(trial[idx])
    inp_rate = sq_rate + 12.5 * trial_rst
    rst_yield_obs =  rst_b0 + rst_b1 * inp_rate + rst_b2g * inp_rate**2

    # Standardize yield and rate. Each uses one scaler fit on all of its
    # values; the original array shapes are restored afterwards.
    y = rst_yield_obs.reshape(-1,1)
    X = inp_rate.reshape(-1,1)

    y_std = StandardScaler().fit(y)
    y = y_std.transform(y).reshape(rst_yield_obs.shape)

    X_std = StandardScaler().fit(X)
    X = X_std.transform(X).reshape(inp_rate.shape)

    # Polynomial design: X[0, k] = x**k for k = 0, 1, 2 -> shape (1, 3, H, W).
    p = np.array([0,1,2])[:,None,None]
    X = X[None,None] ** (np.ones((3, *X.shape)) * p)

    # Initial coefficients: intercept = observed (standardized) yield;
    # first- and second-order terms start at 0.
    c = y[:,None].repeat(3,1)
    c[:,1:] = 0
    c = torch.tensor(c[:,:,ridx[0], ridx[1]], device = device, requires_grad = True)

    X = torch.tensor(X[:,:,ridx[0], ridx[1]], device = device)
    y = torch.tensor(y[:,ridx[0], ridx[1]], device = device)

    # --- Fit ------------------------------------------------------------------
    criterion = nn.MSELoss()
    learning_rate = 0.1
    optimizer = torch.optim.Adam((c,), learning_rate)

    for i in trange(100, desc = exp_name):
        optimizer.zero_grad()

        # One coefficient per replicate, single global second-order term.
        cm = _constrained_coefs(c)

        # Predicted (standardized) yield from the polynomial design.
        pred = (X * cm).sum(1)

        # Data fit plus a penalty pulling the raw coefficients towards their
        # constrained (replicate-mean) version.
        loss = criterion(pred, y) + criterion(cm, c)

        loss.backward()
        optimizer.step()

    # --- Economically optimal rate per replicate -----------------------------
    # Candidate rates: 200 points spanning 0..200 (step 200/199).
    # NOTE(review): linspace(0, 200, 201) would give whole-number rates.
    X_testr = torch.linspace(0, 200, 200, device = device)
    X_cost = X_testr[:,None,None] * nitrogen_cost
    X_test = (X_testr - X_std.mean_[0]) /  X_std.scale_[0]
    p = torch.tensor([[0,1,2]], device = device)[:,:,None,None]
    X_test = X_test[:,None,None,None] ** (torch.ones_like(c)[0][None] * p)

    best_idx = []
    with torch.no_grad():
        # Recompute from the *final* parameters. The `cm` left over from the
        # training loop predates the last optimizer step, and it also carries
        # an autograd graph that evaluation does not need.
        cm = _constrained_coefs(c)
        for i in range(len(c)):
            # Yield curve in original units, then net return per candidate rate.
            y_test = (X_test * cm[i]).sum(1)
            y_test = y_std.mean_[0] + (y_std.scale_[0] * y_test)
            y_net_pred = y_test * cotton_price - X_cost
            best_idx.append(y_net_pred.argmax(0))

    best_idx = torch.stack(best_idx)
    # The optimum is identical for all cells of a replicate, so the mean over
    # K just collapses that axis. The result is then painted back onto the
    # full grid via `repf`.
    rst_optr_pred = X_testr[best_idx].cpu().numpy()
    rst_optr_pred = rst_optr_pred.mean(-1)[:,repf]
    np.save(f'../data/{exp_name}_lm.npy', rst_optr_pred)


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='size03_p001', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size03_p005', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size03_p010', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size03_p050', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size03_p100', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size06_p001', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size06_p005', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size06_p010', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size06_p050', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size06_p100', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size15_p001', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size15_p005', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size15_p010', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size15_p050', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size15_p100', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size30_p001', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size30_p005', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size30_p010', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size30_p050', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size30_p100', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size60_p001', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size60_p005', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size60_p010', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size60_p050', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size60_p100', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size01_p001', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size01_p005', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size01_p010', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size01_p050', style=ProgressStyle(description_width='init…




HBox(children=(FloatProgress(value=0.0, description='size01_p100', style=ProgressStyle(description_width='init…



