In [None]:
# from glob import glob
# import pickle
# import os
# import numpy as np
# a_file = open("Train_Data.pkl", "rb")
# output = pickle.load(a_file)
# a_file.close()
# os.makedirs("Train_Dataset", exist_ok=True)
# train_data = glob("Train_Data/*")
# train_dataset = []
# for td in train_data:
#     if td in output:
#         filename = td.replace("Train_Data/", "Train_Dataset/" + str(output[td]) + "_")
    
#     data = np.load(td)
#     np.save(filename, data)

In [None]:
import torch
from torch import nn
from torch.autograd import Variable
import torchvision
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.nn.functional import avg_pool2d, interpolate
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader
import glob
import math
from functools import partial
from torchvision import transforms, utils
import random
import os
import requests
import xarray as xr
from datetime import datetime
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
    
from collections import OrderedDict
import socket
import time
import random
import rasterio as rio
import pvcz
from datetime import date, datetime, timedelta
import openet.ssebop as model
import ee


from os import path
# Prefer the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# ee.Authenticate()

In [None]:
# Start the Earth Engine client session used by the SSEBop cell further down.
# NOTE(review): presumably relies on credentials saved by an earlier
# ee.Authenticate() run (see the commented-out cell above) - confirm for your
# environment.
ee.Initialize()

In [None]:
def conv1x1(in_channels, out_channels, stride = 1):
    """Return a bias-free 1x1 (pointwise) ``nn.Conv2d`` with no padding."""
    pointwise = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False,
    )
    return pointwise

In [None]:
def conv3x3(in_channels, out_channels, stride = 1):
    """Return a bias-free 3x3 ``nn.Conv2d`` padded by one pixel on every side."""
    spatial = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return spatial

In [None]:
class irnn_layer(nn.Module):
    """One four-direction neighbour sweep over a 4-D feature map.

    Each direction owns a per-channel (``groups=in_channels``) 1x1 convolution.
    ``forward`` returns four maps shaped like the input, in the order
    ``(up, right, down, left)``.
    """

    def __init__(self,in_channels):
        super(irnn_layer,self).__init__()
        # Shared settings of the four per-channel 1x1 convolutions.
        per_channel = dict(kernel_size=1, stride=1, groups=in_channels, padding=0)
        self.left_weight = nn.Conv2d(in_channels, in_channels, **per_channel)
        self.right_weight = nn.Conv2d(in_channels, in_channels, **per_channel)
        self.up_weight = nn.Conv2d(in_channels, in_channels, **per_channel)
        self.down_weight = nn.Conv2d(in_channels, in_channels, **per_channel)

    def forward(self,x):
        _, _, rows, cols = x.shape

        # Every output starts as a copy of the input, so the border row/column
        # that has no neighbour in the given direction keeps its input value.
        # All other positions become relu(weighted neighbour + own value).
        from_left = F.relu(self.left_weight(x)[:, :, :, :cols - 1] + x[:, :, :, 1:])
        top_left = x.clone()
        top_left[:, :, :, 1:] = from_left

        from_right = F.relu(self.right_weight(x)[:, :, :, 1:] + x[:, :, :, :cols - 1])
        top_right = x.clone()
        top_right[:, :, :, :-1] = from_right

        from_above = F.relu(self.up_weight(x)[:, :, :rows - 1, :] + x[:, :, 1:, :])
        top_up = x.clone()
        top_up[:, :, 1:, :] = from_above

        from_below = F.relu(self.down_weight(x)[:, :, 1:, :] + x[:, :, :rows - 1, :])
        top_down = x.clone()
        top_down[:, :, :-1, :] = from_below

        return (top_up, top_right, top_down, top_left)

In [None]:
class Attention(nn.Module):
    """Three-layer conv head that emits four sigmoid-squashed maps.

    Two 3x3 conv + ReLU stages work at ``int(in_channels / 2)`` channels and a
    final 1x1 conv reduces to 4 channels; the spatial size is unchanged.
    ``SAM`` uses those 4 channels as per-direction weights.
    """

    def __init__(self,in_channels):
        super(Attention,self).__init__()
        self.out_channels = int(in_channels/2)
        width = self.out_channels
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(width, 4, kernel_size=1, padding=0, stride=1)
        self.sigmod = nn.Sigmoid()

    def forward(self,x):
        hidden = self.relu1(self.conv1(x))
        hidden = self.relu2(self.conv2(hidden))
        return self.sigmod(self.conv3(hidden))

In [None]:
class SAM(nn.Module):
    """Spatial attention module: two four-direction sweeps gated per direction.

    ``forward`` returns a single-channel sigmoid mask with the same spatial
    size as its input.

    Args:
        in_channels: channels of the input feature map. ``conv1`` is applied
            directly to the input, so this has to equal ``out_channels``.
        out_channels: channel width used inside the module.
        attention: truthy to build and apply the direction-attention branch.

    Note:
        ``conv_in`` and ``relu1`` are constructed but never used by
        ``forward``; they are kept so the module's attributes and
        ``state_dict`` keys stay unchanged.
    """

    def __init__(self,in_channels,out_channels,attention=1):
        super(SAM,self).__init__()
        self.out_channels = out_channels
        self.irnn1 = irnn_layer(self.out_channels)
        self.irnn2 = irnn_layer(self.out_channels)
        self.conv_in = conv3x3(in_channels,self.out_channels)
        self.relu1 = nn.ReLU(True)

        self.conv1 = nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,stride=1,padding=0)
        # conv2/conv3 fuse the four concatenated direction maps back to
        # ``out_channels`` channels.
        self.conv2 = nn.Conv2d(self.out_channels*4,self.out_channels,kernel_size=1,stride=1,padding=0)
        self.conv3 = nn.Conv2d(self.out_channels*4,self.out_channels,kernel_size=1,stride=1,padding=0)
        self.relu2 = nn.ReLU(True)
        self.attention = attention
        if self.attention:
            self.attention_layer = Attention(in_channels)
        self.conv_out = conv1x1(self.out_channels,1)
        self.sigmod = nn.Sigmoid()

    @staticmethod
    def _apply_direction_attention(directions, weight):
        """Scale the (up, right, down, left) maps by weight channels 0..3."""
        if weight is None:
            return tuple(directions)
        # The previous code called ``Tensor.mul`` and discarded the result, so
        # the attention weights never influenced the output. Use the product.
        return tuple(
            feature * weight[:, k:k + 1, :, :]
            for k, feature in enumerate(directions)
        )

    def forward(self,x):
        weight = self.attention_layer(x) if self.attention else None
        out = self.conv1(x)

        # First sweep + direction attention.
        directions = self._apply_direction_attention(self.irnn1(out), weight)
        out = self.conv2(torch.cat(directions, dim=1))

        # Second sweep + direction attention.
        directions = self._apply_direction_attention(self.irnn2(out), weight)
        out = self.conv3(torch.cat(directions, dim=1))

        out = self.relu2(out)
        mask = self.sigmod(self.conv_out(out))
        return mask

In [None]:
class convolutionalCapsule(nn.Module):
    """Convolutional capsule layer with mask gating and dynamic routing.

    Input is viewed as ``(batch, in_capsules, in_channels, W, H)`` and the
    result has shape ``(batch, out_capsules, out_channels, W', H')``.

    Args:
        in_capsules / out_capsules: number of capsule types in / out.
        in_channels / out_channels: vector length of each capsule in / out.
        stride, padding, kernel: settings of the prediction convolution. The
            conv output is multiplied by the ``SAM`` mask of the un-strided
            input, so the two must be spatially broadcast-compatible.
        num_routes: number of routing iterations.
        nonlinearity: ``'relu'`` or ``'leakyRelu'`` replace the squash on the
            LAST routing iteration only; any other value always squashes.
        batch_norm: apply ``BatchNorm2d`` to the prediction vectors.
        dynamic_routing: stored but not read by this class.
        cuda: kept for backward compatibility and stored as ``use_cuda``. The
            routing logits are now created on the device of the activations,
            so the layer works on CPU and GPU alike. It is no longer stored as
            ``self.cuda``, which shadowed ``nn.Module.cuda``.
    """

    def __init__(self, in_capsules, out_capsules, in_channels, out_channels, stride=1, padding=2,
                 kernel=5, num_routes=3, nonlinearity='sqaush', batch_norm=False, dynamic_routing='local', cuda=False):
        super(convolutionalCapsule, self).__init__()
        self.num_routes = num_routes
        self.in_channels = in_channels
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.out_channels = out_channels
        self.nonlinearity = nonlinearity
        self.batch_norm = batch_norm
        self.bn = nn.BatchNorm2d(in_capsules*out_capsules*out_channels)
        self.conv2d = nn.Conv2d(kernel_size=(kernel, kernel), stride=stride, padding=padding,
                                in_channels=in_channels, out_channels=out_channels*out_capsules)
        self.dynamic_routing = dynamic_routing
        self.use_cuda = cuda
        self.SAM1 = SAM(self.in_channels,self.in_channels,1)

    def forward(self, x):
        batch_size = x.size(0)
        in_width, in_height = x.size(3), x.size(4)
        # Fold the capsule axis into the batch so one Conv2d serves every capsule.
        x = x.view(batch_size*self.in_capsules, self.in_channels, in_width, in_height)
        u_hat = self.conv2d(x) * self.SAM1(x)

        out_width, out_height = u_hat.size(2), u_hat.size(3)

        if self.batch_norm:
            # Normalise with all (in_capsule, out_capsule, channel) triples as
            # BatchNorm channels, then move the vector axis last.
            u_hat = u_hat.view(batch_size, self.in_capsules * self.out_capsules * self.out_channels, out_width, out_height)
            u_hat = self.bn(u_hat)
            u_hat = u_hat.view(batch_size, self.in_capsules, self.out_capsules*self.out_channels, out_width, out_height)
            u_hat = u_hat.permute(0,1,3,4,2).contiguous()
        else:
            u_hat = u_hat.permute(0,2,3,1).contiguous()
        u_hat = u_hat.view(batch_size, self.in_capsules, out_width, out_height, self.out_capsules, self.out_channels)

        # Routing logits are shared across the batch (leading dim 1) and are
        # allocated on the same device/dtype as the activations.
        b_ij = torch.zeros(1, self.in_capsules, out_width, out_height, self.out_capsules,
                           device=u_hat.device, dtype=u_hat.dtype)
        last_iteration = self.num_routes - 1
        for iteration in range(self.num_routes):
            # The size-1 batch dim broadcasts against u_hat; no copies needed.
            c_ij = F.softmax(b_ij, dim=1).unsqueeze(5)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

            if (self.nonlinearity == 'relu') and (iteration == last_iteration):
                v_j = F.relu(s_j)
            elif (self.nonlinearity == 'leakyRelu') and (iteration == last_iteration):
                v_j = F.leaky_relu(s_j)
            else:
                v_j = self.squash(s_j)

            v_j = v_j.squeeze(1)

            if iteration < last_iteration:
                # Agreement = dot product of each prediction with the current
                # output vector, averaged over the batch.
                temp = u_hat.permute(0, 2, 3, 4, 1, 5)
                temp2 = v_j.unsqueeze(5)
                a_ij = torch.matmul(temp, temp2).squeeze(5)
                a_ij = a_ij.permute(0, 4, 1, 2, 3)
                b_ij = b_ij + a_ij.mean(dim=0)

        v_j = v_j.permute(0, 3, 4, 1, 2).contiguous()

        return v_j

    def squash(self, input_tensor):
        """Squash vectors along the last axis so their length lies in [0, 1).

        A small epsilon inside the square root keeps both the value and the
        gradient finite for all-zero vectors (the unguarded form gives 0/0).
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        safe_norm = torch.sqrt(squared_norm + 1e-8)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * safe_norm)
        return output_tensor

In [None]:
class _DenseLayer(nn.Sequential):
    """Residual unit: two stacked 3x3 capsule layers added back onto the input.

    Args:
        num_caps: number of capsules entering and leaving each capsule layer.
        actvec_size: length of every capsule vector.
        num_input_features, growth_rate, bn_size, drop_rate: accepted for
            signature compatibility; not used by this layer.

    The capsule count and vector length were previously hard-coded to 8, which
    is what the callers in this notebook pass for ``num_caps`` and
    ``actvec_size``. The residual sum in ``forward`` requires input and output
    capsule shapes to match, so both layers use the same sizes in and out.
    """

    def __init__(self, num_caps, num_input_features, growth_rate, bn_size, drop_rate, actvec_size):
        super().__init__()
        capsule_settings = dict(
            in_capsules=num_caps, out_capsules=num_caps,
            in_channels=actvec_size, out_channels=actvec_size,
            stride=1, padding=1, kernel=3, num_routes=3,
            nonlinearity='sqaush', batch_norm=True,
            dynamic_routing='local',
            # Follow actual CUDA availability instead of forcing True, which
            # broke CPU-only runs.
            cuda=torch.cuda.is_available(),
        )
        self.concap1 = convolutionalCapsule(**capsule_settings)
        self.concap2 = convolutionalCapsule(**capsule_settings)

    def forward(self, x):
        new_features = self.concap1(x)
        new_features = self.concap2(new_features)
        # Identity skip connection.
        return x + new_features

In [None]:
class _DenseBlock(nn.Sequential):

    def __init__(self, num_layers, num_caps, num_input_features, bn_size, growth_rate,
                 drop_rate, actvec_size):
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_caps, num_input_features,
                                growth_rate, bn_size, drop_rate, actvec_size)
            self.add_module('denselayer{}'.format(i + 1), layer)

In [None]:
class _Transition(nn.Sequential):
    """1x1 capsule layer followed by 2x2 average pooling of the spatial axes.

    Args:
        num_caps: number of capsules entering and leaving the layer.
        in_vect: capsule vector length on the input.
        out_vect: capsule vector length on the output.

    These sizes were previously hard-coded to 8, which is what the caller in
    this notebook passes for all three arguments.
    """

    def __init__(self, num_caps, in_vect, out_vect):
        super().__init__()
        self.skip = convolutionalCapsule(in_capsules=num_caps, out_capsules=num_caps,
                                         in_channels=in_vect, out_channels=out_vect,
                                         stride=1, kernel=1, padding=0, num_routes=3,
                                         nonlinearity='sqaush', batch_norm=True,
                                         dynamic_routing='local',
                                         # Follow actual CUDA availability
                                         # instead of forcing True.
                                         cuda=torch.cuda.is_available())

        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        out = self.skip(x)
        batch_size, num_caps, vec_size, height, width = out.shape
        # Fold capsules into the batch so AvgPool2d sees a 4-D tensor, then
        # unfold again with the pooled spatial size.
        out = self.avgpool(out.view(batch_size * num_caps, vec_size, height, width))
        out = out.view(batch_size, num_caps, vec_size, out.shape[2], out.shape[3])
        return out

In [None]:
class DenseNetModel(nn.Module):
    """Capsule regression network arranged in DenseNet-style blocks.

    A strided conv stem produces ``num_init_features`` channels, which are
    viewed as ``num_init_features / actvec_size`` capsules of length
    ``actvec_size`` and passed through ``_DenseBlock`` / ``_Transition``
    stages. The flattened capsule output is concatenated with a 128-d
    embedding of the metadata vector and mapped to one value, which is added
    to ``metadata[:, 0]``.

    Args:
        n_input_channels (int): channels of the input image.
        no_max_pool (bool): skip the max-pool after the stem convolution.
        growth_rate (int), bn_size (int), drop_rate (float): forwarded to
            ``_DenseBlock``.
        block_config (sequence of int): number of layers in each block; a
            ``_Transition`` follows every block except the last.
        num_init_features (int): channels produced by the stem convolution.
        actvec_size (int): capsule vector length.
        conv1_t_size, conv1_t_stride, num_classes: accepted for signature
            compatibility; not used.

    Note:
        ``metadata_network`` expects 17 metadata features, and ``classifier``
        expects exactly 64 flattened capsule features plus the 128-d
        embedding.
    """

    def __init__(self,
                 n_input_channels=5,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 growth_rate=32,
                 block_config=(6, 12, 24, 16),
                 num_init_features=64,
                 bn_size=32,
                 drop_rate=0,
                 num_classes=1,
                 actvec_size=8):

        super().__init__()

        self.num_init_features = num_init_features
        self.actvec_size = actvec_size
        self.num_caps = int(num_init_features/actvec_size)

        # Stem: strided convolution, batch norm, ReLU and an optional max-pool.
        stem_layers = [('conv1',
                        nn.Conv2d(n_input_channels,
                                  num_init_features,
                                  kernel_size=4,
                                  stride=2,
                                  padding=1,
                                  bias=False)),
                       ('norm1', nn.BatchNorm2d(num_init_features)),
                       ('relu1', nn.ReLU(inplace=True))]
        if not no_max_pool:
            stem_layers.append(
                ('pool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)))
        self.input_features = nn.Sequential(OrderedDict(stem_layers))

        # Capsule blocks, with a transition after every block but the last.
        self.features = nn.Sequential()
        num_actvec = actvec_size
        last_block = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_caps=self.num_caps,
                                num_input_features=num_actvec,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate,
                                actvec_size=actvec_size)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_actvec = num_actvec + num_layers * growth_rate
            if i != last_block:
                trans = _Transition(num_caps=self.num_caps,
                                    in_vect=actvec_size,
                                    out_vect=actvec_size)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_actvec = num_actvec // 2

        # Embeds the 17 metadata features into 128 dimensions.
        self.metadata_network = torch.nn.Sequential(
            torch.nn.Linear(17, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 128)
        )

        # Regression head over [flattened capsules (64) | metadata embedding (128)].
        self.classifier = nn.Linear(64 + 128, 1)

        # Single initialisation pass. An earlier duplicate pass that used the
        # deprecated ``nn.init.kaiming_normal`` was fully overwritten by this
        # one and has been removed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x, metadata):
        """Return a ``(batch, 1)`` prediction for images ``x`` and ``metadata``."""
        stem_out = self.input_features(x)
        # Split the stem channels into capsules: (B, caps, vec, H, W).
        capsules = stem_out.view(stem_out.shape[0],
                                 int(self.num_init_features/self.actvec_size),
                                 self.actvec_size,
                                 stem_out.shape[-2],
                                 stem_out.shape[-1])

        out = self.features(capsules)
        y = self.metadata_network(metadata)
        fused = torch.cat((out.view(out.size(0), -1), y), dim=1)
        # The head predicts a correction added to the first metadata feature.
        out = metadata[:, 0].unsqueeze(-1) + self.classifier(fused)
        return out
    
#     def forward(self, x, metadata):
#         out = self.convolutions(x)
#         out = out.view(x.size(0), -1)
#         out_metadata = self.metadata_network(metadata)
#         out = metadata[:, 0] + self.fc(torch.cat((out, out_metadata), dim=1))

# #         out = ((torch.sigmoid(self.fc(torch.cat((out, out_metadata), dim=1))) * (self.range2 - self.range1)) + self.range1)
#         return out

In [None]:
def generate_model(model_depth, **kwargs):
    """Build a ``DenseNetModel`` from one of the named depth presets.

    ``model_depth`` must be 121, 169, 201 or 264 (checked with ``assert``).
    Every preset uses 64 initial features; extra keyword arguments are passed
    straight through to ``DenseNetModel``.
    """
    # depth -> (growth_rate, block_config)
    presets = {
        121: (4, (6, 12, 24, 16)),
        169: (32, (6, 12, 32, 32)),
        201: (32, (6, 12, 48, 32)),
        264: (32, (6, 12, 64, 48)),
    }
    assert model_depth in list(presets)

    growth_rate, block_config = presets[model_depth]
    return DenseNetModel(num_init_features=64,
                         growth_rate=growth_rate,
                         block_config=block_config,
                         **kwargs)

In [None]:
class Model(nn.Module):
    """Wrapper around the depth-121 preset that returns a flat 1-D prediction."""

    def __init__(self):
        super().__init__()

        self.SubModel = generate_model(121)

    def forward(self, x, y):
        # x: image batch, y: metadata batch; both go to the sub-model unchanged.
        return self.SubModel(x, y).flatten()

In [None]:
class EvapoDataset(Dataset):
    """Evapotranspiration samples stored as ``.npy`` files in ``Train_Dataset/``.

    All labels and metadata are encoded in the file name as
    underscore-separated fields. The first field is the OpenET estimate; the
    remaining fields are addressed from the END of the name::

        [-13] lon   [-12] lat   [-11] elevation   [-10] vegetation code
        [-9]  climate code      [-7:-4] year, month, day
        [-1]  target ET (followed by ``.npy``)

    ``__getitem__`` returns, in order: image ``(7, 32, 32)``, target ET,
    one-hot vegetation, one-hot climate, day-of-year sine and cosine, unit
    sphere x/y/z of the location, scaled elevation, date string, lon, lat,
    file name and the OpenET estimate. A random rotation and random flips are
    applied on every access.
    """

    def __init__(self):

        self.file_names = glob.glob("Train_Dataset/*")

        # Vegetation codes; list position is the one-hot index.
        self.vegs = [
            "WAT",
            "ENF",
            "EBF",
            "DNF",
            "DBF",
            "MF",
            "CSH",
            "OSH",
            "WSA",
            "SAV",
            "GRA",
            "WET",
            "CRO",
            "URB",
            "CVM",
            "SNO",
            "BSV",
            "Missing Data"
        ]
        # Climate codes; list position is the one-hot index.
        self.clims = ['DFB', 'BWK', 'CFA', 'CWA', 'DWB', 'DFC', 'DFA', 'BSK', 'CSA', 'BSH']

        print("Dataset Length " + str(len(self.file_names)))

    def __len__(self):
        return len(self.file_names)

    @staticmethod
    def _one_hot(code, vocabulary):
        """One-hot encode ``code`` by its case-insensitive position in ``vocabulary``.

        Raises ``ValueError`` when the code is not in the vocabulary.
        """
        index = [entry.upper() for entry in vocabulary].index(code.upper())
        return torch.nn.functional.one_hot(torch.tensor(index), num_classes=len(vocabulary))

    @staticmethod
    def _split_fields(file_name):
        """Split the base name into its underscore-separated fields.

        Only the base name is parsed, so the result does not depend on the
        directory part of the path or on the path separator.
        """
        fields = os.path.basename(file_name).split("_")
        if len(fields) < 14:
            raise ValueError(
                "expected at least 14 '_'-separated fields in file name, got "
                + str(len(fields)) + ": " + str(file_name))
        return fields

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.file_names[idx]
        fields = self._split_fields(file_name)

        openet = float(fields[0])
        lon = float(fields[-13])
        lat = float(fields[-12])
        # Scaled by 8848.0 as in the original code (NOTE(review): presumably
        # the height of Mount Everest in metres - confirm).
        elev = np.array([float(fields[-11])/8848.0])
        veg = self._one_hot(fields[-10], self.vegs)
        clim = self._one_hot(fields[-9], self.clims)
        date = "_".join(fields[-7:-4])
        et = float(fields[-1].replace(".npy", ""))

        # Cyclic encoding of the day of year. The 364.0 period is kept from
        # the original so the features stay unchanged.
        day_of_year = datetime.strptime(date, '%Y_%m_%d').timetuple().tm_yday
        day_sin = torch.tensor([np.sin(2 * np.pi * day_of_year/364.0)])
        day_cos = torch.tensor([np.cos(2 * np.pi * day_of_year/364.0)])

        # Location as a point on the unit sphere (polar angle = 90 deg - lat).
        polar = math.pi/2-np.deg2rad(lat)
        azimuth = np.deg2rad(lon)
        x_coord = torch.tensor([np.sin(polar) * np.cos(azimuth)])
        y_coord = torch.tensor([np.sin(polar) * np.sin(azimuth)])
        z_coord = torch.tensor([np.cos(polar)])

        img = torch.from_numpy(np.load(file_name).astype(float))
        # Resize to 32x32 and drop the leading dimension.
        img = interpolate(img , size=32)[0]

        # Channels 19/20 hold lon/lat in either order; a negative mean in
        # channel 20 is taken to mean that channel 20 is longitude.
        # Afterwards channel 19 is always lon and channel 20 always lat.
        if img[20].mean() < 0:
            lon_img = img[20].clone()
            lat_img = img[19].clone()
        else:
            lat_img = img[20].clone()
            lon_img = img[19].clone()
        img[19] = lon_img
        img[20] = lat_img

        # Rescale raw band values; constants kept exactly as in the original.
        img[[1,2,3,4]] = ((img[[1,2,3,4]] *0.0000275)-0.2)
        img[7] = ((img[7] * 0.00341802) + 149.0)/400.0

        # Output channels: Blue, Green, Red, NIR, LST, Lon, Lat.
        output_img = img[[1,2,3,4,7,19,20]]

        # Augmentation: 0-3 quarter turns, then independent flips on each
        # spatial axis.
        quarter_turns = random.randint(0, 3)
        if quarter_turns:
            output_img = torch.rot90(output_img, quarter_turns, [1, 2])

        if random.randint(0, 1) == 1:
            output_img = torch.flip(output_img, (1,))

        if random.randint(0, 1) == 1:
            output_img = torch.flip(output_img, (2,))

        return output_img, et, veg, clim, day_sin, day_cos, x_coord, y_coord, z_coord, elev, date, lon, lat, file_name, openet


In [None]:
# Module-level dataset and shuffled loader. TrainResAttnCap (defined below)
# builds its own dataset and loaders, and no visible cell reads `train_loader`,
# so this cell mainly serves to print the dataset length.
train_dataset = EvapoDataset()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=935, 
                                           num_workers=4, shuffle=True,
                                           pin_memory=True)

In [None]:
class SSEBop(torch.nn.Module):
    """Looks up SSEBop evapotranspiration through Earth Engine, with a cache.

    ``forward(date, lat, lon)`` takes parallel sequences (dates formatted
    ``YYYY_MM_DD``) and returns a 1-D tensor of ET values on ``device``.
    Results are memoised in ``self.dict`` under ``"<date>_<lon>_<lat>"`` so
    each sample is fetched at most once.

    Args:
        device: device of the returned tensor.
        max_attempts: upper bound on fetch attempts per sample. Each failed
            or empty attempt enlarges the clip buffer by 225. The previous
            implementation retried forever behind a bare ``except``, which
            could hang and also swallowed ``KeyboardInterrupt``.
    """

    def __init__(self, device, max_attempts=40):
        super().__init__()
        self.lc = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
        self.scale = 1000
        self.device = device
        self.dict = {}
        self.max_attempts = max_attempts

    def _fetch_et(self, day, lon, lat):
        """Query Earth Engine for one sample and return its ET value.

        Raises:
            IndexError: no Landsat image was found for the day and location.
            RuntimeError: no non-empty ET value after ``max_attempts`` tries.
        """
        next_day = str(datetime.strptime(day, "%Y-%m-%d") + timedelta(days=1))[:10]
        lc = self.lc.filterDate(day, next_day)
        # Degenerate rectangle, i.e. a single point.
        site1point = ee.Geometry.Rectangle(lon.item(), lat.item(), lon.item(), lat.item())
        lc_r_poi = lc.getRegion(site1point, self.scale).getInfo()
        # Row 0 of getRegion's result is the header; column 0 is the image id.
        id_list = sorted(row[0] for row in lc_r_poi[1:])
        print("THISONE", id_list)
        print("date", day)
        print("lon", lon)
        print("lat", lat)
        if not id_list:
            raise IndexError("no Landsat image found for " + str(day)
                             + " at lon=" + str(lon) + ", lat=" + str(lat))
        landsat_img = ee.Image('LANDSAT/LC08/C02/T1_L2/' + id_list[-1])

        # The reduction region is fixed; only the clip buffer grows on retry.
        landsat_region = site1point.buffer(225)
        point_size = 225
        last_error = None
        for _ in range(self.max_attempts):
            try:
                model_obj = model.Image.from_landsat_c2_sr(
                    landsat_img.clip(site1point.buffer(point_size)),
                    tcorr_source='FANO',
                    et_reference_source='IDAHO_EPSCOR/GRIDMET',
                    et_reference_band='etr',
                    et_reference_factor=0.85,
                    et_reference_resample='nearest',
                )
                # One round trip, reused for both the log line and the value.
                reduced = model_obj.et.reduceRegion(ee.Reducer.max(), landsat_region).getInfo()
                print("HERE2", reduced)
                et = reduced["et"]
            except Exception as exc:
                last_error = exc
                et = None
            # As before, a falsy value (None or 0) counts as "no data".
            if et:
                return et
            point_size = point_size + 225
            print("grow " + str(point_size))

        raise RuntimeError(
            "no ET value for " + str(day) + " at lon=" + str(lon) + ", lat=" + str(lat)
            + " after " + str(self.max_attempts) + " attempts") from last_error

    def _openet(self, date, lon, lat):
        """Return one ET value per sample, using the cache where possible."""
        ets = []
        for i1 in range(len(date)):
            key = str(date[i1]) + "_" + str(lon[i1]) + "_" + str(lat[i1])
            if key not in self.dict:
                self.dict[key] = self._fetch_et(date[i1], lon[i1], lat[i1])
            ets.append(self.dict[key])
        return ets

    def forward(self, date, lat, lon):
        # Dates arrive as YYYY_MM_DD; the Earth Engine filter wants YYYY-MM-DD.
        date_str = [dstr.replace("_", "-") for dstr in date]

        return torch.tensor(self._openet(date_str, lon, lat), device=self.device)

In [None]:
def test_quench(epoch, model, loader):
    with torch.no_grad():
        et_correct = 0
        counter = 0
        for img_seq, et, veg, clim, day_sin, day_cos, x_coord, y_coord, z_coord, elev, dat, lon, lat, _, openet in loader:
            img_seq = img_seq.to(device=device, dtype=torch.float32)
            et = et.to(device=device, dtype=torch.float32)
            veg = veg.to(device=device, dtype=torch.float32)
            clim = clim.to(device=device, dtype=torch.float32)
            day_sin = day_sin.to(device=device, dtype=torch.float32)
            day_cos = day_cos.to(device=device, dtype=torch.float32)
            x_coord = x_coord.to(device=device, dtype=torch.float32)
            y_coord = y_coord.to(device=device, dtype=torch.float32)
            z_coord = z_coord.to(device=device, dtype=torch.float32)
            elev = elev.to(device=device, dtype=torch.float32)
            lat = lat.to(device=device, dtype=torch.float32)
            lon = lon.to(device=device, dtype=torch.float32)
            openet = openet.to(device=device, dtype=torch.float32)
        
#             ssebop_ET = ssebop(dat, lat, lon)
            ssebop_ET = openet.reshape(openet.shape[0], -1)
        
            output = model(img_seq[:, 0:5], torch.cat((ssebop_ET, clim, day_sin, day_cos, x_coord, y_coord, z_coord, elev), dim=1))
        
            et_correct += (torch.sum(torch.abs((output-et))))
            counter += output.shape[0]

        
        return str(round(float(et_correct.sum() / counter), 4))

In [None]:
class TrainResAttnCap():
    """Training harness for ``Model`` on ``EvapoDataset``.

    Args:
        level: unused; kept for signature compatibility.
        epochs: number of training epochs (previously ignored in favour of a
            hard-coded 501).
        batch_size: batch size of both data loaders.
        torch_type: stored on the instance; tensors are still moved as
            ``torch.float32``.

    Note:
        ``test_loader`` iterates the SAME dataset as ``train_loader``, so the
        "Test" figure printed each epoch is an error on the training data.
    """

    def __init__(self, level=1, epochs=500, batch_size=512, torch_type=torch.float32):
        super(TrainResAttnCap, self).__init__()

        self.epochs = epochs
        self.batch_size = batch_size
        self.device = "cpu"
        if torch.cuda.is_available(): self.device = "cuda"
        self.torch_type = torch_type
        self.model_name = "ResAttnCap"

        self.mse = torch.nn.MSELoss()
        self.model = Model().to(self.device, dtype=torch.float32)

        self.train_dataset = EvapoDataset()

        self.dataset_size = len(self.train_dataset)
        self.indices = list(range(self.dataset_size))

        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2, pin_memory=True)

        self.test_loader = torch.utils.data.DataLoader(self.train_dataset,batch_size=self.batch_size, num_workers=2, pin_memory=True)

        self.opt = torch.optim.Adagrad(self.model.parameters(), lr=0.01)
        self.sched = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10, gamma=0.999)

    def _to_device(self, tensor):
        """Move a batch tensor to the training device as float32."""
        return tensor.to(device=self.device, dtype=torch.float32)

    def train(self):
        """Run the training loop, printing the error and checkpointing every 50 epochs."""
        # Create the checkpoint directory up front so the first save cannot
        # fail after 50 epochs of work.
        os.makedirs("Checkpoints", exist_ok=True)

        for epoch in range(1, self.epochs + 1):
            # Make sure layers such as BatchNorm are in training mode.
            self.model.train()
            for ind, (img_seq, et, veg, clim, day_sin, day_cos, x_coord, y_coord, z_coord, elev, dat, lon, lat, _, openet) in enumerate(self.train_loader):
                img_seq = self._to_device(img_seq)
                true_et = self._to_device(et)
                # The stored OpenET estimate is the first metadata feature.
                # It was previously moved to the notebook-level `device`
                # rather than `self.device`.
                ssebop_ET = self._to_device(openet)
                ssebop_ET = ssebop_ET.reshape(ssebop_ET.shape[0], -1)
                metadata = torch.cat(
                    (ssebop_ET,
                     self._to_device(clim),
                     self._to_device(day_sin),
                     self._to_device(day_cos),
                     self._to_device(x_coord),
                     self._to_device(y_coord),
                     self._to_device(z_coord),
                     self._to_device(elev)),
                    dim=1)

                self.opt.zero_grad()
                output = self.model(img_seq[:, 0:5], metadata)
                loss = self.mse(output, true_et)
                loss.backward()
                self.opt.step()
            self.sched.step()
            test_accuracy = test_quench(epoch, self.model, self.test_loader)
            print("Epoch " + str(epoch) + ", Test " + test_accuracy )
            if epoch % 50 == 0:
                torch.save(self.model, "Checkpoints/Quench_" + str(epoch) + ".pt" )
                print("Saved " + str(epoch) + " Epoch Model")
             



In [None]:
# Build the training harness with its default arguments and start training.
# Constructing it instantiates the model and scans Train_Dataset/ for samples.
trainer = TrainResAttnCap()
trainer.train()