In [1]:
import os
import h5py
import numpy as np
import pandas as pd
import geopandas as gpd
import seaborn as sbn
import datetime as dt
import matplotlib.pyplot as plt
from tqdm.notebook import  tqdm, trange

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Compute device for the model and all tensors. Prefer the GPU, but fall back
# to the CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class MyNet(nn.Module):
    """1-D CNN regressor mapping a multichannel time series to one value.

    Parameters
    ----------
    w_layers : list of [in_channels, out_channels, kernel_size, dropout]
        One convolutional block per entry:
        Conv1d -> AvgPool1d(2) -> LeakyReLU -> BatchNorm1d -> Dropout.
        An AdaptiveAvgPool1d(1) after the last block collapses the time axis.
    pred_layers : list of [in_features, out_features, dropout]
        One dense block per entry:
        Linear -> LeakyReLU -> BatchNorm1d -> Dropout,
        followed by a final Linear projecting to a single output.

    Raises
    ------
    ValueError
        If both lists are empty (the width of the final Linear cannot be
        inferred).
    """

    def __init__(self, w_layers, pred_layers):

        super().__init__()

        # Width of the features reaching the final Linear; updated by every
        # block built below so the last one wins.
        feat_dim = None

        # Weather variables: convolutional feature extractor.
        w_blocks = []
        for i, o, k, d in w_layers:
            w_blocks.append(nn.Sequential(
                nn.Conv1d(i, o, k),
                nn.AvgPool1d(2),
                nn.LeakyReLU(),
                nn.BatchNorm1d(o),
                nn.Dropout(d),
            ))
            feat_dim = o
        w_blocks.append(nn.AdaptiveAvgPool1d(1))
        self.w_layers = nn.Sequential(*w_blocks)

        # Management variables: fully connected prediction head.
        pred_blocks = []
        for i, o, d in pred_layers:
            pred_blocks.append(nn.Sequential(
                nn.Linear(i, o),
                nn.LeakyReLU(),
                nn.BatchNorm1d(o),
                nn.Dropout(d),
            ))
            feat_dim = o
        if feat_dim is None:
            raise ValueError(
                'MyNet needs at least one w_layers or pred_layers entry to '
                'infer the input width of the output layer')
        # Track the width explicitly rather than reading a loop variable after
        # the loop; the module layout (and thus state_dict keys) is unchanged.
        pred_blocks.append(nn.Linear(feat_dim, 1))
        self.pred_layers = nn.Sequential(*pred_blocks)

    def forward(self, Ws):
        """Predict one value per sample.

        Ws : tensor shaped (batch, channels, length), as Conv1d expects.
        Returns a (batch, 1) tensor squashed by tanh into [-1, 1].
        """
        feat = self.w_layers(Ws).view(Ws.shape[0], -1)
        pred = self.pred_layers(feat)
        return torch.tanh(pred)


In [4]:
def transform(w, target_device=None):
    """Scale raw weather values and build the model's input tensor.

    Parameters
    ----------
    w : array-like, shape (300, 5) or (batch, 300, 5)
        Raw weather values, one column per variable. Each column is divided
        by its factor in ``ws`` below. A 2-D input gains a batch axis of 1
        through broadcasting against the 3-D ``ws``.
    target_device : torch.device or str, optional
        Device for the returned tensor; defaults to the notebook-level
        ``device``.

    Returns
    -------
    torch.Tensor, float32, shape (batch, 6, 300)
        Channels 0-4 are the scaled variables. Channel 5 is a fixed linear ramp
        from -0.9 to 2.1. It appears to encode position within the window; the
        main loop slices 90 days before to 210 days after the start date.

    Raises
    ------
    ValueError
        If the time axis does not have exactly 300 steps.
    """
    ws = np.array([[[5e4,50,50,5,100.0]]])
    # (T, 5) / (batch, T, 5) -> scaled (batch, T, 5) -> (batch, 5, T)
    w = np.moveaxis(np.asarray(w) / ws, 1, 2)

    n_steps = 300
    if w.shape[2] != n_steps:
        raise ValueError(
            f'transform expects {n_steps} time steps, got {w.shape[2]}')

    # Extra ramp channel, repeated once per sample.
    wd = np.linspace(-0.9, 2.1, n_steps)[None, None].repeat(len(w), 0)
    w = np.concatenate([w, wd], 1)

    dev = device if target_device is None else target_device
    return torch.tensor(w, dtype=torch.float, device=dev)

def back_transform(w):
    """Undo ``transform`` scaling for the first sample of a batch.

    Drops the trailing (ramp) channel, moves time back to the first data
    axis and multiplies by the same per-column scale factors.

    w : torch.Tensor shaped (batch, 6, T)
    returns : numpy.ndarray shaped (T, 5), taken from sample 0
    """
    scale = np.array([[[5e4, 50, 50, 5, 100.0]]])
    weather_only = w[:, :-1].detach().cpu().numpy()
    rescaled = np.swapaxes(weather_only, 1, 2) * scale
    return rescaled[0]


def get_adv(x, eps=1, net=None, std=None, steps=10, lr=0.01):
    """Search for a perturbed input that increases the model's prediction.

    Starting from ``x``, runs ``steps`` Adam updates on a copy of the input.
    Each update minimises ``-mean(net(x))`` plus a ``std``-weighted MSE
    penalty towards the starting point, divided by ``eps``. After every
    step the copy is clamped to [0, 1] and its last channel is restored to
    the original values, so only the other channels are perturbed.

    Parameters
    ----------
    x : torch.Tensor
        Model input (e.g. from ``transform``). It is NOT modified.
    eps : float
        Larger values weaken the distance penalty and allow bigger changes.
    net : torch.nn.Module, optional
        Model to query; defaults to the notebook-level ``model``.
    std : torch.Tensor, optional
        Penalty weights broadcastable against ``x``; defaults to the
        notebook-level ``w_std``.
    steps : int
        Number of optimiser steps.
    lr : float
        Adam learning rate.

    Returns
    -------
    torch.Tensor
        Detached perturbed input with the same shape and device as ``x``.

    NOTE: if ``net`` has parameters with requires_grad=True, backward() also
    accumulates (unused) gradients into them.
    """
    net = model if net is None else net
    std = w_std if std is None else std

    xo = x.detach().clone()
    # Optimise a private leaf copy so the caller's tensor is left untouched.
    x_adv = xo.clone().requires_grad_(True)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam([x_adv], lr=lr)

    for _ in range(steps):

        # Clear the gradients from the previous step
        optimizer.zero_grad()

        # Get the model output
        outputs = net(x_adv)

        # Maximise the prediction by minimising its negative mean...
        loss = -outputs.mean()

        # ...with a weighted L2 penalty for drifting from the start point
        loss = loss + criterion(std * xo, std * x_adv) / eps

        # Compute gradients w.r.t. the input copy
        loss.backward()

        # Update the input copy
        optimizer.step()

        # Clip to the valid range of values and restore the last channel
        with torch.no_grad():
            x_adv.clamp_(0, 1)
            x_adv[:, -1] = xo[:, -1]

    return x_adv.detach()

In [5]:
# MyNet architecture. Dropout rate shared by every block:
d = 0.0

# Convolutional blocks: [in_channels, out_channels, kernel_size, dropout].
# 6 input channels = the 5 scaled weather columns + the ramp added by transform().
w_layers = [
    [6, 12, 3, d],
    [12, 15, 5, d],
    [15, 20, 7, d],
    [20, 25, 5, d],
    [25, 100, 3, d],
]

# Dense blocks: [in_features, out_features, dropout]; the first input width
# matches the 100 channels produced by the last convolutional block.
pred_layers = [
    [100, 50, d],
    [50, 50, d],
    [50, 25, d],
]

# Trained weights are loaded per (REP, train_method, PCT) in the evaluation loop.

In [6]:
# Values used to scale the weather data (one factor per weather column).
# NOTE: transform()/back_transform() each define their own local copy.
ws = np.array([[[5e4, 50, 50, 5, 100.0]]])

# Per-channel weights for the MSE (L2) penalty in get_adv(): five weather
# channels, then 0.0 for the ramp channel so it contributes nothing.
# Shaped (6, 1) to broadcast over (batch, 6, T) inputs.
w_std = torch.tensor(
    [0.08034062, 0.08567557, 0.08080445, 0.14946058, 0.10540102, 0.0],
    device=device,
    dtype=torch.float,
).unsqueeze(-1)

In [7]:
# Simulation results, indexed by the SIM column (looked up per site via ydf.loc[site]).
ydf = pd.read_hdf('../data/PSCE_TILE.h5', key='SIM')
ydf = ydf.set_index('SIM')
# Scaled yield: TWSO / 2e4 (model predictions are multiplied back by 2e4).
ydf['Yield'] = ydf['TWSO'] / 2e4

In [8]:
epsilon = 1e-2  # passed as `eps` to get_adv; the distance penalty is divided by it

In [9]:
# Directory with the per-pixel DAYMET weather CSVs. Set the DAYMET_TILE_DIR
# environment variable to point elsewhere instead of editing this absolute path.
wdir = os.environ.get('DAYMET_TILE_DIR', '/home/rodrigo7/Apsim_test/MASAGRO/DAYMET_TILE')
# 40 x 40 pixel grid. px varies fastest: wfiles[k] is pixel (k % 40, k // 40).
# The order matters because wridx rows are matched to files by position.
pxy = np.stack(np.meshgrid(np.arange(40), np.arange(40)), -1).reshape(-1, 2)
wfiles = [f'{wdir}/DAYMET_9584_{px:02d}_{py:02d}.csv' for px, py in pxy]

In [10]:
# Provenance: adv_idx.npy was generated once (without a seed) by
#
#     wridx = []
#     for wfile in wfiles:
#         ridx = 10 * np.arange(8, 10) + np.random.randint(0, 10, 2)
#         wridx.append(ridx)
#     wridx = np.array(wridx)
#     np.save('../data/adv_idx.npy', wridx)
#
# i.e. for every weather file, two positional indices into that site's
# SIM_DATE values: one in 80-89 and one in 90-99. The saved file is loaded
# instead of regenerating it, so the selection stays reproducible.
wridx = np.load(file='../data/adv_idx.npy')

In [None]:
# Evaluate every trained model (replicate x training method x PCT):
# - for each weather file and each of its two selected crop start dates,
#   predict on the observed 300-day window;
# - search for a perturbed window with get_adv and predict on it;
# - write the perturbed window out as a CSV in the source file's format.
train_methods = ['none', 'rnd', 'adv']

N_HEADER_LINES = 14                 # 13 preamble lines + the column-name row
DAYS_BEFORE, DAYS_AFTER = 90, 210   # window around the crop start (300 rows)
WEATHER_COLS = [1, 2, 3, 4, 6]      # CSV columns fed to the model


def _copy_header(src_file, dst_file, n_lines=N_HEADER_LINES):
    """Create ``dst_file`` containing the first ``n_lines`` lines of ``src_file``."""
    with open(src_file) as src, open(dst_file, 'w') as dst:
        for _ in range(n_lines):
            dst.write(src.readline())


for REP in range(1, 6):
    for train_method in train_methods:
        for PCT in [1, 5]:
            model = MyNet(w_layers, pred_layers)
            model = model.to(device)
            model_file_name = f'../data/model_cnn_{train_method}_{PCT}_{REP}.pth'

            model.load_state_dict(torch.load(model_file_name, map_location=device))
            model.eval()
            # Only the input is optimised in get_adv; freezing the weights
            # avoids computing and accumulating unused parameter gradients.
            model.requires_grad_(False)

            # Include PCT and REP in every output name: without them each
            # (REP, PCT) pass overwrote the previous pass's CSVs and .npy.
            run_tag = f'{train_method}_{PCT}_{REP}_05'

            yl = []
            for wfile, ridx in zip(tqdm(wfiles), wridx):
                site = os.path.basename(wfile).replace('.csv', '')

                w = pd.read_csv(wfile, skiprows = N_HEADER_LINES - 1)
                w.DAY = pd.to_datetime(w.DAY, format = '%Y%m%d').dt.date
                sydf = ydf.loc[site]

                for crop_start_date in sydf.SIM_DATE.values[ridx]:
                    cs_date = np.where(w.DAY == crop_start_date)[0][0]
                    # A negative slice start would silently wrap around.
                    if cs_date < DAYS_BEFORE or cs_date + DAYS_AFTER > len(w):
                        raise ValueError(
                            f'{site}: {crop_start_date} does not leave '
                            f'{DAYS_BEFORE} days before and {DAYS_AFTER} days '
                            'after it in the weather file')
                    w_SIM = w.iloc[cs_date - DAYS_BEFORE:cs_date + DAYS_AFTER].copy()
                    w_seed = transform(w_SIM.iloc[:, WEATHER_COLS].values)

                    with torch.no_grad():
                        y_pred = model(w_seed).cpu().numpy()
                    w_adv = get_adv(w_seed, epsilon)
                    with torch.no_grad():
                        y_adv = model(w_adv).cpu().numpy()
                    yl.append([y_pred, y_adv])

                    w_SIM.iloc[:, WEATHER_COLS] = back_transform(w_adv)
                    # Bound VAP (presumably to the range the crop model accepts — confirm).
                    w_SIM.VAP = np.clip(w_SIM.VAP, 0.06, 199.3)

                    plant_date = format(crop_start_date, '%Y%m%d')
                    save_file = wfile.replace('.csv', f'_{plant_date}_opt_{run_tag}.csv')
                    _copy_header(wfile, save_file)

                    w_SIM.DAY = pd.to_datetime(w_SIM.DAY).dt.strftime('%Y%m%d')
                    w_SIM.to_csv(save_file, na_rep = 'NaN', mode = 'a', float_format = '%.3f', header = False, index = False)

            # (n_dates, 2) array of [original, perturbed] predictions, undoing
            # the Yield = TWSO / 2e4 scaling.
            ynp = 2e4 * np.array(yl)[:, :, 0, 0]
            np.save(f'../data/y_pred_opt_{run_tag}.npy', ynp)


HBox(children=(FloatProgress(value=0.0, max=1600.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1600.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1600.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1600.0), HTML(value='')))